diff --git a/CHANGELOG.md b/CHANGELOG.md index 229921847..4912e00e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- feat(example): add OpenAI-compatible embeddings endpoint by @abetlen in #2281 + ## [0.3.27] - feat: update llama.cpp to ggml-org/llama.cpp@465b1f0e7 diff --git a/examples/server/README.md b/examples/server/README.md index 1f6e0d3db..6a1cc4448 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -1,7 +1,7 @@ # Server Example This example is an updated OpenAI-compatible web server that depends only on the low-level C bindings. -It supports batched inference, prompt caching, response parsing, `/v1/responses`, disk sequence caching, MTP, LoRA, and multimodal image/audio inputs. +It supports batched inference, prompt caching, response parsing, `/v1/responses`, `/v1/embeddings`, disk sequence caching, MTP, LoRA, and multimodal image/audio inputs. ## Setup @@ -46,6 +46,7 @@ The smallest checked-in example uses Qwen3.5 0.8B so the server can be started o | Config | Model | Notes | | --- | --- | --- | +| [`configs/bge-small-en-v1.5.json`](configs/bge-small-en-v1.5.json) | [`CompendiumLabs/bge-small-en-v1.5-gguf`](https://huggingface.co/CompendiumLabs/bge-small-en-v1.5-gguf) | Small embedding model config for `/v1/embeddings`. | | [`configs/qwen3.5-0.8b.json`](configs/qwen3.5-0.8b.json) | [`lmstudio-community/Qwen3.5-0.8B-GGUF`](https://huggingface.co/lmstudio-community/Qwen3.5-0.8B-GGUF) | Default small multimodal example. | | [`configs/gemma-4-12b-it-qat.json`](configs/gemma-4-12b-it-qat.json) | [`unsloth/gemma-4-12B-it-qat-GGUF`](https://huggingface.co/unsloth/gemma-4-12B-it-qat-GGUF) | Larger Gemma 4 QAT multimodal config with projector. | | [`configs/qwen3.6-27b.json`](configs/qwen3.6-27b.json) | [`unsloth/Qwen3.6-27B-GGUF`](https://huggingface.co/unsloth/Qwen3.6-27B-GGUF) | Larger Qwen3.6 multimodal config. | @@ -86,11 +87,33 @@ response = client.responses.create( print(response.output_text) ``` +### Embeddings + +Start the server with an embedding config before calling `/v1/embeddings`. + +```bash +cd examples/server +uv run --script server.py -C configs/bge-small-en-v1.5.json +``` + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="not-used") + +response = client.embeddings.create( + model="bge-small-en-v1.5", + input=["The food was delicious.", "The meal was excellent."], +) +print(len(response.data[0].embedding)) +``` + ## API Surface | Endpoint | Purpose | Reference | | --- | --- | --- | | `POST /v1/completions` | Legacy text completions with streaming, stop sequences, logprobs, penalties, seeds, and grammar-backed JSON output. | [OpenAI Completions API](https://platform.openai.com/docs/api-reference/completions) | +| `POST /v1/embeddings` | OpenAI-compatible embeddings for embedding-mode GGUF models, including string inputs, token inputs, base64 output, and dimensions truncation. | [OpenAI Embeddings API](https://platform.openai.com/docs/api-reference/embeddings) | | `POST /v1/chat/completions` | Chat completions with streaming, tools, forced tool choice, reasoning parsing, multimodal content parts, and structured response parsing. | [OpenAI Chat API](https://platform.openai.com/docs/api-reference/chat) | | `POST /v1/responses` | Stateless Responses API compatibility for clients that use response items and response events. | [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) | | `WS /v1/responses` | Stateful websocket Responses transport with per-connection `previous_response_id` replay. | [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) | @@ -190,6 +213,8 @@ Most model runtime fields map to `llama_model_params` or `llama_context_params` | `threads` | Decode thread count. | | `threads_batch` | Prefill and batch thread count. | | `kv_unified` | Selects unified or per-sequence memory layout. | +| `embedding` | Overrides embedding mode; omit to auto-detect pooled embedding GGUFs from model metadata. | +| `pooling_type` | Overrides pooled embedding behavior for embedding models, such as `1` for mean pooling. | | `store_logits` | Keeps logits after decode when needed by sampling or diagnostics. | | `use_mmap` | Memory maps model weights. | | `use_mlock` | Attempts to lock model pages into RAM. | diff --git a/examples/server/configs/bge-small-en-v1.5.json b/examples/server/configs/bge-small-en-v1.5.json new file mode 100644 index 000000000..3fc8016df --- /dev/null +++ b/examples/server/configs/bge-small-en-v1.5.json @@ -0,0 +1,22 @@ +{ + "server": { + "host": "0.0.0.0", + "port": 8000 + }, + "model": { + "alias": "bge-small-en-v1.5", + "from_pretrained": { + "repo_id": "CompendiumLabs/bge-small-en-v1.5-gguf", + "filename": "bge-small-en-v1.5-q4_k_m.gguf" + }, + "n_ctx": 512, + "n_seq_max": 16, + "n_batch": 512, + "n_ubatch": 512, + "threads": 4, + "threads_batch": 8, + "kv_unified": true, + "store_logits": false, + "use_mmap": true + } +} diff --git a/examples/server/server.py b/examples/server/server.py index 590a21420..2530cf4be 100644 --- a/examples/server/server.py +++ b/examples/server/server.py @@ -2457,6 +2457,7 @@ class CompletionChunk(TypedDict): CompletionStream = Generator[CompletionChunk, None, OpenAICompletion] CompletionPrompt = Union[str, List[int], List[str], List[List[int]]] +EmbeddingInput = Union[str, List[str], List[int], List[List[int]]] class CreateCompletionRequest(BaseModel): @@ -2522,6 +2523,126 @@ def normalized_prompt(self) -> List[Union[str, List[int]]]: raise ValueError("prompt must be a string, token ids, list of strings, or list of token-id lists") +class CreateEmbeddingRequest(BaseModel): + model_config = ConfigDict(extra="ignore") + + input: EmbeddingInput + model: str + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = Field(default=None, ge=1) + user: Optional[str] = None + + @staticmethod + def _validate_text_input(text: str) -> str: + if text == "": + raise ValueError("embedding input must not contain empty strings") + return text + + @staticmethod + def _validate_token_input(tokens: List[int]) -> List[int]: + if not tokens: + raise ValueError("embedding token input must not be empty") + if len(tokens) > 2048: + raise ValueError("embedding token input must not exceed 2048 tokens") + return tokens + + @model_validator(mode="after") + def validate_after(self) -> "CreateEmbeddingRequest": + self.normalized_input() + return self + + def normalized_input(self) -> List[Union[str, List[int]]]: + if isinstance(self.input, str): + return [self._validate_text_input(self.input)] + if all(isinstance(token, int) for token in self.input): + return [self._validate_token_input(cast(List[int], self.input))] + if all(isinstance(item, str) for item in self.input): + if len(self.input) > 2048: + raise ValueError("embedding input array must not exceed 2048 items") + return [ + self._validate_text_input(item) + for item in cast(List[str], self.input) + ] + if all( + isinstance(item, list) + and all(isinstance(token, int) for token in item) + for item in self.input + ): + if len(self.input) > 2048: + raise ValueError("embedding input array must not exceed 2048 items") + return [ + self._validate_token_input(item) + for item in cast(List[List[int]], self.input) + ] + raise ValueError( + "embedding input must be a string, list of strings, token ids, or list of token-id lists" + ) + + +class EmbeddingDataResponse(BaseModel): + object: Literal["embedding"] = "embedding" + embedding: Union[List[float], str] + index: int + + +class EmbeddingUsageResponse(BaseModel): + prompt_tokens: int + total_tokens: int + + +class CreateEmbeddingResponse(BaseModel): + object: Literal["list"] = "list" + data: List[EmbeddingDataResponse] + model: str + usage: EmbeddingUsageResponse + + @staticmethod + def encode_embedding( + embedding: Sequence[float], + encoding_format: Literal["float", "base64"], + dimensions: Optional[int], + ) -> Union[List[float], str]: + if dimensions is not None: + if dimensions > len(embedding): + raise CompletionRequestValidationError( + f"dimensions ({dimensions}) exceeds embedding size ({len(embedding)})" + ) + embedding = embedding[:dimensions] + if encoding_format == "float": + return [float(value) for value in embedding] + array = np.asarray(embedding, dtype=np.float32) + return base64.b64encode(array.tobytes()).decode("ascii") + + @classmethod + def from_embeddings( + cls, + *, + model: str, + embeddings: Sequence[Sequence[float]], + total_tokens: int, + encoding_format: Literal["float", "base64"], + dimensions: Optional[int], + ) -> "CreateEmbeddingResponse": + return cls( + data=[ + EmbeddingDataResponse( + embedding=cls.encode_embedding( + embedding, + encoding_format, + dimensions, + ), + index=index, + ) + for index, embedding in enumerate(embeddings) + ], + model=model, + usage=EmbeddingUsageResponse( + prompt_tokens=total_tokens, + total_tokens=total_tokens, + ), + ) + + class ChatCompletionFunctionCall(BaseModel): model_config = ConfigDict( extra="allow", @@ -3214,6 +3335,7 @@ class ModelOptions(BaseModel): rope_scaling_type: Optional[int] = None pooling_type: Optional[int] = None attention_type: Optional[int] = None + embedding: Optional[bool] = None rope_freq_base: Optional[float] = None rope_freq_scale: Optional[float] = None yarn_ext_factor: Optional[float] = None @@ -10408,6 +10530,7 @@ def __init__( rope_scaling_type: Optional[int] = None, pooling_type: Optional[int] = None, attention_type: Optional[int] = None, + embedding: Optional[bool] = None, rope_freq_base: Optional[float] = None, rope_freq_scale: Optional[float] = None, yarn_ext_factor: Optional[float] = None, @@ -10473,9 +10596,13 @@ def __init__( if vocab is None: raise RuntimeError("failed to access model vocabulary") self.vocab = vocab - if llama_cpp.llama_model_has_encoder(llama_model): + embedding = self.resolve_embedding_mode(llama_model, embedding) + self.embedding = embedding + self.has_encoder = bool(llama_cpp.llama_model_has_encoder(llama_model)) + self.has_decoder = bool(llama_cpp.llama_model_has_decoder(llama_model)) + if self.has_encoder and not embedding: raise RuntimeError("encoder models are not supported") - if not llama_cpp.llama_model_has_decoder(llama_model): + if not self.has_decoder and not (embedding and self.has_encoder): raise RuntimeError("decoder is required") if llama_cpp.llama_model_is_recurrent(llama_model): self.memory_model = "recurrent" @@ -10519,6 +10646,7 @@ def __init__( rope_scaling_type=rope_scaling_type, pooling_type=pooling_type, attention_type=attention_type, + embedding=embedding, rope_freq_base=rope_freq_base, rope_freq_scale=rope_freq_scale, yarn_ext_factor=yarn_ext_factor, @@ -10542,7 +10670,7 @@ def __init__( raise RuntimeError("failed to create context") self.ctx = ctx mem = llama_cpp.llama_get_memory(ctx) - if mem is None: + if mem is None and not embedding: raise RuntimeError("failed to access model memory") self.mem = mem self.n_ctx = int(llama_cpp.llama_n_ctx(ctx)) @@ -10571,7 +10699,11 @@ def __init__( ) self.n_ctx_train = n_ctx_train self.n_vocab = int(llama_cpp.llama_vocab_n_tokens(self.vocab)) + self.n_embd = int(llama_cpp.llama_model_n_embd(self.llama_model)) self.n_embd_inp = int(llama_cpp.llama_model_n_embd_inp(self.llama_model)) + self.n_embd_out = int(llama_cpp.llama_model_n_embd_out(self.llama_model)) + if self.n_embd_out <= 0: + self.n_embd_out = self.n_embd self.kv_unified = kv_unified self.max_seq_len_limit = min(self.request_context_limit, self.n_ctx_train) if max_seq_len is None: @@ -10644,6 +10776,7 @@ def __init__( rope_scaling_type=rope_scaling_type, pooling_type=pooling_type, attention_type=attention_type, + embedding=embedding, rope_freq_base=rope_freq_base, rope_freq_scale=rope_freq_scale, yarn_ext_factor=yarn_ext_factor, @@ -10786,6 +10919,7 @@ def build_context_params( rope_scaling_type: Optional[int], pooling_type: Optional[int], attention_type: Optional[int], + embedding: bool, rope_freq_base: Optional[float], rope_freq_scale: Optional[float], yarn_ext_factor: Optional[float], @@ -10830,6 +10964,7 @@ def build_context_params( context_params.pooling_type = pooling_type if attention_type is not None: context_params.attention_type = attention_type + context_params.embeddings = embedding if rope_freq_base is not None: context_params.rope_freq_base = rope_freq_base if rope_freq_scale is not None: @@ -10940,14 +11075,34 @@ def close(self) -> None: llama_cpp.llama_backend_free() self.backend_initialized = False - def _meta_value(self, key: str) -> Optional[str]: + @staticmethod + def _model_meta_key_by_index(llama_model: Any, index: int) -> Optional[str]: + capacity = 256 + while True: + buffer = ctypes.create_string_buffer(capacity) + count = int( + llama_cpp.llama_model_meta_key_by_index( + llama_model, + index, + cast(Any, buffer), + capacity, + ) + ) + if count < 0: + return None + if count < capacity: + return buffer.value.decode("utf-8", errors="ignore") + capacity = count + 1 + + @staticmethod + def _model_meta_value(llama_model: Any, key: str) -> Optional[str]: encoded = key.encode("utf-8") capacity = 256 while True: buffer = ctypes.create_string_buffer(capacity) count = int( llama_cpp.llama_model_meta_val_str( - self.llama_model, + llama_model, encoded, cast(Any, buffer), capacity, @@ -10959,6 +11114,50 @@ def _meta_value(self, key: str) -> Optional[str]: return buffer.value.decode("utf-8", errors="ignore") capacity = count + 1 + @staticmethod + def _parse_pooling_type(value: str) -> Optional[int]: + normalized = value.strip().lower() + try: + return int(normalized) + except ValueError: + return { + "none": llama_cpp.LLAMA_POOLING_TYPE_NONE, + "mean": llama_cpp.LLAMA_POOLING_TYPE_MEAN, + "cls": llama_cpp.LLAMA_POOLING_TYPE_CLS, + "last": llama_cpp.LLAMA_POOLING_TYPE_LAST, + "rank": llama_cpp.LLAMA_POOLING_TYPE_RANK, + }.get(normalized) + + @classmethod + def detect_embedding_model(cls, llama_model: Any) -> bool: + for index in range(int(llama_cpp.llama_model_meta_count(llama_model))): + key = cls._model_meta_key_by_index(llama_model, index) + if key is None or not key.endswith(".pooling_type"): + continue + value = cls._model_meta_value(llama_model, key) + if value is None: + continue + pooling_type = cls._parse_pooling_type(value) + return pooling_type in { + llama_cpp.LLAMA_POOLING_TYPE_MEAN, + llama_cpp.LLAMA_POOLING_TYPE_CLS, + llama_cpp.LLAMA_POOLING_TYPE_LAST, + } + return False + + @classmethod + def resolve_embedding_mode( + cls, + llama_model: Any, + embedding: Optional[bool], + ) -> bool: + if embedding is not None: + return embedding + return cls.detect_embedding_model(llama_model) + + def _meta_value(self, key: str) -> Optional[str]: + return self._model_meta_value(self.llama_model, key) + def _build_chat_formatter(self) -> Optional[Jinja2ChatFormatter]: template_text = self.chat_template_override if template_text is None: @@ -11131,6 +11330,10 @@ def clear_batch(self) -> None: self._embedding_batch = None self._embedding_batch_refs = [] + def clear_memory(self) -> None: + if self.mem is not None: + llama_cpp.llama_memory_clear(self.mem, True) + def add_batch_tokens( self, *, @@ -11206,9 +11409,14 @@ def add_batch_embeddings( def decode(self) -> None: batch = self._embedding_batch if self._embedding_batch is not None else self.batch - result = int(llama_cpp.llama_decode(self.ctx, batch)) + if self.embedding and self.has_encoder: + operation = "llama_encode" + result = int(llama_cpp.llama_encode(self.ctx, batch)) + else: + operation = "llama_decode" + result = int(llama_cpp.llama_decode(self.ctx, batch)) if result != 0: - raise RuntimeError(f"llama_decode failed with code {result}") + raise RuntimeError(f"{operation} failed with code {result}") def process_draft_batch(self) -> None: if self.draft_provider is not None: @@ -11242,6 +11450,103 @@ def logits(self, output_index: int) -> np.ndarray: raise RuntimeError(f"missing logits output {output_index}") return np.ctypeslib.as_array(ptr, shape=(self.n_vocab,)).copy() + def embed( + self, + inputs: Sequence[Union[str, List[int]]], + ) -> Tuple[List[List[float]], int]: + if not self.embedding: + raise CompletionRequestValidationError( + "model.embedding must be true to use /v1/embeddings" + ) + pooling_type = int(llama_cpp.llama_pooling_type(self.ctx)) + if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE: + raise CompletionRequestValidationError( + "/v1/embeddings requires a pooled embedding model; " + "set model.pooling_type to MEAN, CLS, or LAST" + ) + if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_RANK: + raise CompletionRequestValidationError( + "/v1/embeddings does not support reranking pooling" + ) + if len(inputs) > 2048: + raise CompletionRequestValidationError( + "embedding input batch size exceeds 2048" + ) + + embeddings: List[List[float]] = [] + total_tokens = 0 + batch_sizes: List[int] = [] + batch_token_count = 0 + + def decode_embedding_batch() -> None: + nonlocal batch_token_count + if not batch_sizes: + return + self.clear_memory() + self.decode() + self.clear_batch() + for seq_id in range(len(batch_sizes)): + ptr = llama_cpp.llama_get_embeddings_seq( + self.ctx, + llama_cpp.llama_seq_id(seq_id), + ) + if not ptr: + raise RuntimeError(f"missing embedding output for input {seq_id}") + embeddings.append( + np.ctypeslib.as_array(ptr, shape=(self.n_embd_out,)).astype( + float + ).tolist() + ) + batch_sizes.clear() + batch_token_count = 0 + + try: + self.clear_batch() + self.clear_memory() + for input_item in inputs: + tokens = ( + self.tokenize(input_item) + if isinstance(input_item, str) + else list(input_item) + ) + n_tokens = len(tokens) + if n_tokens == 0: + raise CompletionRequestValidationError( + "embedding input must not be empty" + ) + if n_tokens > self.n_ctx_seq: + raise CompletionRequestValidationError( + f"embedding input has {n_tokens} tokens, exceeding n_ctx_seq ({self.n_ctx_seq})" + ) + if n_tokens > self.n_batch: + raise CompletionRequestValidationError( + f"embedding input has {n_tokens} tokens, exceeding n_batch ({self.n_batch})" + ) + if total_tokens + n_tokens > 300_000: + raise CompletionRequestValidationError( + "embedding request exceeds 300000 total tokens" + ) + if ( + batch_token_count + n_tokens > self.n_batch + or len(batch_sizes) >= self.n_seq_max + ): + decode_embedding_batch() + seq_id = len(batch_sizes) + self.add_batch_tokens( + seq_id=seq_id, + start_pos=0, + tokens=tokens, + output_indices=[0] * n_tokens, + ) + batch_sizes.append(n_tokens) + batch_token_count += n_tokens + total_tokens += n_tokens + decode_embedding_batch() + finally: + self.clear_batch() + self.clear_memory() + return embeddings, total_tokens + class SequenceDiskCache(SequenceCache): """Directory-backed cache for serialized llama.cpp sequence state.""" @@ -12065,6 +12370,37 @@ def build_memory_policy(self) -> MemoryPolicy: return PartitionedAttentionMemoryPolicy(self) return UnifiedAttentionMemoryPolicy(self) + def clear_resident_state(self) -> None: + self.model.clear_memory() + self.model.clear_batch() + self.radix_trie = RadixTrie() + self.sequence_history = SequenceHistory() + self.checkpoint_logits.clear() + self.claimed_sequences.clear() + self.free_sequences.clear() + self.unused_sequences = list(range(self.model.n_seq_max - 1, -1, -1)) + for seq_id in range(self.model.n_seq_max): + self.model.truncate_draft_sequence(seq_id, 0) + + def create_embedding( + self, + payload: CreateEmbeddingRequest, + ) -> CreateEmbeddingResponse: + if not self.is_idle(): + raise RuntimeError("embedding requests require an idle scheduler") + self.clear_resident_state() + try: + embeddings, total_tokens = self.model.embed(payload.normalized_input()) + return CreateEmbeddingResponse.from_embeddings( + model=payload.model, + embeddings=embeddings, + total_tokens=total_tokens, + encoding_format=payload.encoding_format, + dimensions=payload.dimensions, + ) + finally: + self.clear_resident_state() + @staticmethod def request_needs_prompt_logits(request: CompletionRequest) -> bool: return request.payload.max_tokens != 0 and request.effective_max_len > len( @@ -14420,6 +14756,39 @@ def run_callback() -> None: raise error_box["error"] return result_box.get("result") + def call_on_idle_scheduler(self, callback: Callable[[], Any]) -> Any: + result_box: Dict[str, Any] = {} + error_box: Dict[str, BaseException] = {} + done = threading.Event() + + def run_callback() -> None: + if not self.scheduler.is_idle(): + with self.condition: + self.commands.appendleft(run_callback) + self.condition.notify_all() + return + try: + result_box["result"] = callback() + except BaseException as exc: # noqa: BLE001 + error_box["error"] = exc + finally: + done.set() + + self.enqueue(run_callback) + done.wait() + if "error" in error_box: + raise error_box["error"] + return result_box.get("result") + + def create_embedding( + self, + payload: CreateEmbeddingRequest, + ) -> CreateEmbeddingResponse: + embedding = self.call_on_idle_scheduler( + lambda: self.scheduler.create_embedding(payload) + ) + return cast(CreateEmbeddingResponse, embedding) + def render_prometheus_metrics(self) -> str: metrics = self.call_on_scheduler(self.scheduler.render_prometheus_metrics) return cast(str, metrics) @@ -14874,6 +15243,17 @@ async def create_completion( # pyright: ignore[reportUnusedFunction] return result return JSONResponse(result.model_dump(mode="json", exclude_none=True)) + @app.post("/v1/embeddings", response_model=CreateEmbeddingResponse) + async def create_embedding( # pyright: ignore[reportUnusedFunction] + body: CreateEmbeddingRequest, + ) -> JSONResponse: + service: CompletionService = app.state.service + try: + embedding = await asyncio.to_thread(service.create_embedding, body) + except CompletionRequestValidationError as exc: + raise bad_request(exc) from exc + return JSONResponse(embedding.model_dump(mode="json", exclude_none=True)) + @app.post("/v1/chat/completions") async def create_chat_completion( # pyright: ignore[reportUnusedFunction] http_request: Request, body: CreateChatCompletionRequest @@ -15250,6 +15630,7 @@ def main() -> None: rope_scaling_type=config.model.rope_scaling_type, pooling_type=config.model.pooling_type, attention_type=config.model.attention_type, + embedding=config.model.embedding, rope_freq_base=config.model.rope_freq_base, rope_freq_scale=config.model.rope_freq_scale, yarn_ext_factor=config.model.yarn_ext_factor,