From 4d8d0e159ab47aa6c0d58518f4f02d6a79c52bb8 Mon Sep 17 00:00:00 2001 From: abetlen Date: Sun, 7 Jun 2026 06:15:27 -0700 Subject: [PATCH 1/4] feat(example): add embeddings endpoint --- examples/server/README.md | 27 +- .../server/configs/bge-small-en-v1.5.json | 23 ++ examples/server/server.py | 326 +++++++++++++++++- tests/test_server_example_embeddings.py | 110 ++++++ 4 files changed, 480 insertions(+), 6 deletions(-) create mode 100644 examples/server/configs/bge-small-en-v1.5.json create mode 100644 tests/test_server_example_embeddings.py diff --git a/examples/server/README.md b/examples/server/README.md index 1f6e0d3dbe..6e32444288 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -1,7 +1,7 @@ # Server Example This example is an updated OpenAI-compatible web server that depends only on the low-level C bindings. -It supports batched inference, prompt caching, response parsing, `/v1/responses`, disk sequence caching, MTP, LoRA, and multimodal image/audio inputs. +It supports batched inference, prompt caching, response parsing, `/v1/responses`, `/v1/embeddings`, disk sequence caching, MTP, LoRA, and multimodal image/audio inputs. ## Setup @@ -46,6 +46,7 @@ The smallest checked-in example uses Qwen3.5 0.8B so the server can be started o | Config | Model | Notes | | --- | --- | --- | +| [`configs/bge-small-en-v1.5.json`](configs/bge-small-en-v1.5.json) | [`CompendiumLabs/bge-small-en-v1.5-gguf`](https://huggingface.co/CompendiumLabs/bge-small-en-v1.5-gguf) | Small embedding model config for `/v1/embeddings`. | | [`configs/qwen3.5-0.8b.json`](configs/qwen3.5-0.8b.json) | [`lmstudio-community/Qwen3.5-0.8B-GGUF`](https://huggingface.co/lmstudio-community/Qwen3.5-0.8B-GGUF) | Default small multimodal example. | | [`configs/gemma-4-12b-it-qat.json`](configs/gemma-4-12b-it-qat.json) | [`unsloth/gemma-4-12B-it-qat-GGUF`](https://huggingface.co/unsloth/gemma-4-12B-it-qat-GGUF) | Larger Gemma 4 QAT multimodal config with projector. | | [`configs/qwen3.6-27b.json`](configs/qwen3.6-27b.json) | [`unsloth/Qwen3.6-27B-GGUF`](https://huggingface.co/unsloth/Qwen3.6-27B-GGUF) | Larger Qwen3.6 multimodal config. | @@ -86,11 +87,33 @@ response = client.responses.create( print(response.output_text) ``` +### Embeddings + +Start the server with an embedding config before calling `/v1/embeddings`. + +```bash +cd examples/server +uv run --script server.py -C configs/bge-small-en-v1.5.json +``` + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="not-used") + +response = client.embeddings.create( + model="bge-small-en-v1.5", + input=["The food was delicious.", "The meal was excellent."], +) +print(len(response.data[0].embedding)) +``` + ## API Surface | Endpoint | Purpose | Reference | | --- | --- | --- | | `POST /v1/completions` | Legacy text completions with streaming, stop sequences, logprobs, penalties, seeds, and grammar-backed JSON output. | [OpenAI Completions API](https://platform.openai.com/docs/api-reference/completions) | +| `POST /v1/embeddings` | OpenAI-compatible embeddings for embedding-mode GGUF models, including string inputs, token inputs, base64 output, and dimensions truncation. | [OpenAI Embeddings API](https://platform.openai.com/docs/api-reference/embeddings) | | `POST /v1/chat/completions` | Chat completions with streaming, tools, forced tool choice, reasoning parsing, multimodal content parts, and structured response parsing. | [OpenAI Chat API](https://platform.openai.com/docs/api-reference/chat) | | `POST /v1/responses` | Stateless Responses API compatibility for clients that use response items and response events. | [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) | | `WS /v1/responses` | Stateful websocket Responses transport with per-connection `previous_response_id` replay. | [OpenAI Responses API](https://platform.openai.com/docs/api-reference/responses) | @@ -190,6 +213,8 @@ Most model runtime fields map to `llama_model_params` or `llama_context_params` | `threads` | Decode thread count. | | `threads_batch` | Prefill and batch thread count. | | `kv_unified` | Selects unified or per-sequence memory layout. | +| `embedding` | Enables llama.cpp embedding extraction for `/v1/embeddings`. | +| `pooling_type` | Selects pooled embedding behavior for embedding models, such as `1` for mean pooling. | | `store_logits` | Keeps logits after decode when needed by sampling or diagnostics. | | `use_mmap` | Memory maps model weights. | | `use_mlock` | Attempts to lock model pages into RAM. | diff --git a/examples/server/configs/bge-small-en-v1.5.json b/examples/server/configs/bge-small-en-v1.5.json new file mode 100644 index 0000000000..cf8bd0c566 --- /dev/null +++ b/examples/server/configs/bge-small-en-v1.5.json @@ -0,0 +1,23 @@ +{ + "server": { + "host": "0.0.0.0", + "port": 8000 + }, + "model": { + "alias": "bge-small-en-v1.5", + "from_pretrained": { + "repo_id": "CompendiumLabs/bge-small-en-v1.5-gguf", + "filename": "bge-small-en-v1.5-q4_k_m.gguf" + }, + "embedding": true, + "n_ctx": 512, + "n_seq_max": 16, + "n_batch": 512, + "n_ubatch": 512, + "threads": 4, + "threads_batch": 8, + "kv_unified": true, + "store_logits": false, + "use_mmap": true + } +} diff --git a/examples/server/server.py b/examples/server/server.py index 590a214208..cbd92fe6fb 100644 --- a/examples/server/server.py +++ b/examples/server/server.py @@ -2457,6 +2457,7 @@ class CompletionChunk(TypedDict): CompletionStream = Generator[CompletionChunk, None, OpenAICompletion] CompletionPrompt = Union[str, List[int], List[str], List[List[int]]] +EmbeddingInput = Union[str, List[str], List[int], List[List[int]]] class CreateCompletionRequest(BaseModel): @@ -2522,6 +2523,126 @@ def normalized_prompt(self) -> List[Union[str, List[int]]]: raise ValueError("prompt must be a string, token ids, list of strings, or list of token-id lists") +class CreateEmbeddingRequest(BaseModel): + model_config = ConfigDict(extra="ignore") + + input: EmbeddingInput + model: str + encoding_format: Literal["float", "base64"] = "float" + dimensions: Optional[int] = Field(default=None, ge=1) + user: Optional[str] = None + + @staticmethod + def _validate_text_input(text: str) -> str: + if text == "": + raise ValueError("embedding input must not contain empty strings") + return text + + @staticmethod + def _validate_token_input(tokens: List[int]) -> List[int]: + if not tokens: + raise ValueError("embedding token input must not be empty") + if len(tokens) > 2048: + raise ValueError("embedding token input must not exceed 2048 tokens") + return tokens + + @model_validator(mode="after") + def validate_after(self) -> "CreateEmbeddingRequest": + self.normalized_input() + return self + + def normalized_input(self) -> List[Union[str, List[int]]]: + if isinstance(self.input, str): + return [self._validate_text_input(self.input)] + if all(isinstance(token, int) for token in self.input): + return [self._validate_token_input(cast(List[int], self.input))] + if all(isinstance(item, str) for item in self.input): + if len(self.input) > 2048: + raise ValueError("embedding input array must not exceed 2048 items") + return [ + self._validate_text_input(item) + for item in cast(List[str], self.input) + ] + if all( + isinstance(item, list) + and all(isinstance(token, int) for token in item) + for item in self.input + ): + if len(self.input) > 2048: + raise ValueError("embedding input array must not exceed 2048 items") + return [ + self._validate_token_input(item) + for item in cast(List[List[int]], self.input) + ] + raise ValueError( + "embedding input must be a string, list of strings, token ids, or list of token-id lists" + ) + + +class EmbeddingDataResponse(BaseModel): + object: Literal["embedding"] = "embedding" + embedding: Union[List[float], str] + index: int + + +class EmbeddingUsageResponse(BaseModel): + prompt_tokens: int + total_tokens: int + + +class CreateEmbeddingResponse(BaseModel): + object: Literal["list"] = "list" + data: List[EmbeddingDataResponse] + model: str + usage: EmbeddingUsageResponse + + @staticmethod + def encode_embedding( + embedding: Sequence[float], + encoding_format: Literal["float", "base64"], + dimensions: Optional[int], + ) -> Union[List[float], str]: + if dimensions is not None: + if dimensions > len(embedding): + raise CompletionRequestValidationError( + f"dimensions ({dimensions}) exceeds embedding size ({len(embedding)})" + ) + embedding = embedding[:dimensions] + if encoding_format == "float": + return [float(value) for value in embedding] + array = np.asarray(embedding, dtype=np.float32) + return base64.b64encode(array.tobytes()).decode("ascii") + + @classmethod + def from_embeddings( + cls, + *, + model: str, + embeddings: Sequence[Sequence[float]], + total_tokens: int, + encoding_format: Literal["float", "base64"], + dimensions: Optional[int], + ) -> "CreateEmbeddingResponse": + return cls( + data=[ + EmbeddingDataResponse( + embedding=cls.encode_embedding( + embedding, + encoding_format, + dimensions, + ), + index=index, + ) + for index, embedding in enumerate(embeddings) + ], + model=model, + usage=EmbeddingUsageResponse( + prompt_tokens=total_tokens, + total_tokens=total_tokens, + ), + ) + + class ChatCompletionFunctionCall(BaseModel): model_config = ConfigDict( extra="allow", @@ -3214,6 +3335,7 @@ class ModelOptions(BaseModel): rope_scaling_type: Optional[int] = None pooling_type: Optional[int] = None attention_type: Optional[int] = None + embedding: bool = False rope_freq_base: Optional[float] = None rope_freq_scale: Optional[float] = None yarn_ext_factor: Optional[float] = None @@ -10408,6 +10530,7 @@ def __init__( rope_scaling_type: Optional[int] = None, pooling_type: Optional[int] = None, attention_type: Optional[int] = None, + embedding: bool = False, rope_freq_base: Optional[float] = None, rope_freq_scale: Optional[float] = None, yarn_ext_factor: Optional[float] = None, @@ -10443,6 +10566,7 @@ def __init__( self.chat_template_override = chat_template self.response_schema = response_schema self.store_logits = store_logits + self.embedding = embedding self.max_output_tokens = max_output_tokens self.draft_model_max_batch_size = draft_model_max_batch_size self.draft_provider: Optional[DraftProvider] = None @@ -10473,9 +10597,11 @@ def __init__( if vocab is None: raise RuntimeError("failed to access model vocabulary") self.vocab = vocab - if llama_cpp.llama_model_has_encoder(llama_model): + self.has_encoder = bool(llama_cpp.llama_model_has_encoder(llama_model)) + self.has_decoder = bool(llama_cpp.llama_model_has_decoder(llama_model)) + if self.has_encoder and not embedding: raise RuntimeError("encoder models are not supported") - if not llama_cpp.llama_model_has_decoder(llama_model): + if not self.has_decoder and not (embedding and self.has_encoder): raise RuntimeError("decoder is required") if llama_cpp.llama_model_is_recurrent(llama_model): self.memory_model = "recurrent" @@ -10519,6 +10645,7 @@ def __init__( rope_scaling_type=rope_scaling_type, pooling_type=pooling_type, attention_type=attention_type, + embedding=embedding, rope_freq_base=rope_freq_base, rope_freq_scale=rope_freq_scale, yarn_ext_factor=yarn_ext_factor, @@ -10542,7 +10669,7 @@ def __init__( raise RuntimeError("failed to create context") self.ctx = ctx mem = llama_cpp.llama_get_memory(ctx) - if mem is None: + if mem is None and not embedding: raise RuntimeError("failed to access model memory") self.mem = mem self.n_ctx = int(llama_cpp.llama_n_ctx(ctx)) @@ -10571,7 +10698,11 @@ def __init__( ) self.n_ctx_train = n_ctx_train self.n_vocab = int(llama_cpp.llama_vocab_n_tokens(self.vocab)) + self.n_embd = int(llama_cpp.llama_model_n_embd(self.llama_model)) self.n_embd_inp = int(llama_cpp.llama_model_n_embd_inp(self.llama_model)) + self.n_embd_out = int(llama_cpp.llama_model_n_embd_out(self.llama_model)) + if self.n_embd_out <= 0: + self.n_embd_out = self.n_embd self.kv_unified = kv_unified self.max_seq_len_limit = min(self.request_context_limit, self.n_ctx_train) if max_seq_len is None: @@ -10644,6 +10775,7 @@ def __init__( rope_scaling_type=rope_scaling_type, pooling_type=pooling_type, attention_type=attention_type, + embedding=embedding, rope_freq_base=rope_freq_base, rope_freq_scale=rope_freq_scale, yarn_ext_factor=yarn_ext_factor, @@ -10786,6 +10918,7 @@ def build_context_params( rope_scaling_type: Optional[int], pooling_type: Optional[int], attention_type: Optional[int], + embedding: bool, rope_freq_base: Optional[float], rope_freq_scale: Optional[float], yarn_ext_factor: Optional[float], @@ -10830,6 +10963,7 @@ def build_context_params( context_params.pooling_type = pooling_type if attention_type is not None: context_params.attention_type = attention_type + context_params.embeddings = embedding if rope_freq_base is not None: context_params.rope_freq_base = rope_freq_base if rope_freq_scale is not None: @@ -11131,6 +11265,10 @@ def clear_batch(self) -> None: self._embedding_batch = None self._embedding_batch_refs = [] + def clear_memory(self) -> None: + if self.mem is not None: + llama_cpp.llama_memory_clear(self.mem, True) + def add_batch_tokens( self, *, @@ -11206,9 +11344,14 @@ def add_batch_embeddings( def decode(self) -> None: batch = self._embedding_batch if self._embedding_batch is not None else self.batch - result = int(llama_cpp.llama_decode(self.ctx, batch)) + if self.embedding and self.has_encoder: + operation = "llama_encode" + result = int(llama_cpp.llama_encode(self.ctx, batch)) + else: + operation = "llama_decode" + result = int(llama_cpp.llama_decode(self.ctx, batch)) if result != 0: - raise RuntimeError(f"llama_decode failed with code {result}") + raise RuntimeError(f"{operation} failed with code {result}") def process_draft_batch(self) -> None: if self.draft_provider is not None: @@ -11242,6 +11385,103 @@ def logits(self, output_index: int) -> np.ndarray: raise RuntimeError(f"missing logits output {output_index}") return np.ctypeslib.as_array(ptr, shape=(self.n_vocab,)).copy() + def embed( + self, + inputs: Sequence[Union[str, List[int]]], + ) -> Tuple[List[List[float]], int]: + if not self.embedding: + raise CompletionRequestValidationError( + "model.embedding must be true to use /v1/embeddings" + ) + pooling_type = int(llama_cpp.llama_pooling_type(self.ctx)) + if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_NONE: + raise CompletionRequestValidationError( + "/v1/embeddings requires a pooled embedding model; " + "set model.pooling_type to MEAN, CLS, or LAST" + ) + if pooling_type == llama_cpp.LLAMA_POOLING_TYPE_RANK: + raise CompletionRequestValidationError( + "/v1/embeddings does not support reranking pooling" + ) + if len(inputs) > 2048: + raise CompletionRequestValidationError( + "embedding input batch size exceeds 2048" + ) + + embeddings: List[List[float]] = [] + total_tokens = 0 + batch_sizes: List[int] = [] + batch_token_count = 0 + + def decode_embedding_batch() -> None: + nonlocal batch_token_count + if not batch_sizes: + return + self.clear_memory() + self.decode() + self.clear_batch() + for seq_id in range(len(batch_sizes)): + ptr = llama_cpp.llama_get_embeddings_seq( + self.ctx, + llama_cpp.llama_seq_id(seq_id), + ) + if not ptr: + raise RuntimeError(f"missing embedding output for input {seq_id}") + embeddings.append( + np.ctypeslib.as_array(ptr, shape=(self.n_embd_out,)).astype( + float + ).tolist() + ) + batch_sizes.clear() + batch_token_count = 0 + + try: + self.clear_batch() + self.clear_memory() + for input_item in inputs: + tokens = ( + self.tokenize(input_item) + if isinstance(input_item, str) + else list(input_item) + ) + n_tokens = len(tokens) + if n_tokens == 0: + raise CompletionRequestValidationError( + "embedding input must not be empty" + ) + if n_tokens > self.n_ctx_seq: + raise CompletionRequestValidationError( + f"embedding input has {n_tokens} tokens, exceeding n_ctx_seq ({self.n_ctx_seq})" + ) + if n_tokens > self.n_batch: + raise CompletionRequestValidationError( + f"embedding input has {n_tokens} tokens, exceeding n_batch ({self.n_batch})" + ) + if total_tokens + n_tokens > 300_000: + raise CompletionRequestValidationError( + "embedding request exceeds 300000 total tokens" + ) + if ( + batch_token_count + n_tokens > self.n_batch + or len(batch_sizes) >= self.n_seq_max + ): + decode_embedding_batch() + seq_id = len(batch_sizes) + self.add_batch_tokens( + seq_id=seq_id, + start_pos=0, + tokens=tokens, + output_indices=[0] * n_tokens, + ) + batch_sizes.append(n_tokens) + batch_token_count += n_tokens + total_tokens += n_tokens + decode_embedding_batch() + finally: + self.clear_batch() + self.clear_memory() + return embeddings, total_tokens + class SequenceDiskCache(SequenceCache): """Directory-backed cache for serialized llama.cpp sequence state.""" @@ -12065,6 +12305,37 @@ def build_memory_policy(self) -> MemoryPolicy: return PartitionedAttentionMemoryPolicy(self) return UnifiedAttentionMemoryPolicy(self) + def clear_resident_state(self) -> None: + self.model.clear_memory() + self.model.clear_batch() + self.radix_trie = RadixTrie() + self.sequence_history = SequenceHistory() + self.checkpoint_logits.clear() + self.claimed_sequences.clear() + self.free_sequences.clear() + self.unused_sequences = list(range(self.model.n_seq_max - 1, -1, -1)) + for seq_id in range(self.model.n_seq_max): + self.model.truncate_draft_sequence(seq_id, 0) + + def create_embedding( + self, + payload: CreateEmbeddingRequest, + ) -> CreateEmbeddingResponse: + if not self.is_idle(): + raise RuntimeError("embedding requests require an idle scheduler") + self.clear_resident_state() + try: + embeddings, total_tokens = self.model.embed(payload.normalized_input()) + return CreateEmbeddingResponse.from_embeddings( + model=payload.model, + embeddings=embeddings, + total_tokens=total_tokens, + encoding_format=payload.encoding_format, + dimensions=payload.dimensions, + ) + finally: + self.clear_resident_state() + @staticmethod def request_needs_prompt_logits(request: CompletionRequest) -> bool: return request.payload.max_tokens != 0 and request.effective_max_len > len( @@ -14420,6 +14691,39 @@ def run_callback() -> None: raise error_box["error"] return result_box.get("result") + def call_on_idle_scheduler(self, callback: Callable[[], Any]) -> Any: + result_box: Dict[str, Any] = {} + error_box: Dict[str, BaseException] = {} + done = threading.Event() + + def run_callback() -> None: + if not self.scheduler.is_idle(): + with self.condition: + self.commands.appendleft(run_callback) + self.condition.notify_all() + return + try: + result_box["result"] = callback() + except BaseException as exc: # noqa: BLE001 + error_box["error"] = exc + finally: + done.set() + + self.enqueue(run_callback) + done.wait() + if "error" in error_box: + raise error_box["error"] + return result_box.get("result") + + def create_embedding( + self, + payload: CreateEmbeddingRequest, + ) -> CreateEmbeddingResponse: + embedding = self.call_on_idle_scheduler( + lambda: self.scheduler.create_embedding(payload) + ) + return cast(CreateEmbeddingResponse, embedding) + def render_prometheus_metrics(self) -> str: metrics = self.call_on_scheduler(self.scheduler.render_prometheus_metrics) return cast(str, metrics) @@ -14874,6 +15178,17 @@ async def create_completion( # pyright: ignore[reportUnusedFunction] return result return JSONResponse(result.model_dump(mode="json", exclude_none=True)) + @app.post("/v1/embeddings", response_model=CreateEmbeddingResponse) + async def create_embedding( # pyright: ignore[reportUnusedFunction] + body: CreateEmbeddingRequest, + ) -> JSONResponse: + service: CompletionService = app.state.service + try: + embedding = await asyncio.to_thread(service.create_embedding, body) + except CompletionRequestValidationError as exc: + raise bad_request(exc) from exc + return JSONResponse(embedding.model_dump(mode="json", exclude_none=True)) + @app.post("/v1/chat/completions") async def create_chat_completion( # pyright: ignore[reportUnusedFunction] http_request: Request, body: CreateChatCompletionRequest @@ -15250,6 +15565,7 @@ def main() -> None: rope_scaling_type=config.model.rope_scaling_type, pooling_type=config.model.pooling_type, attention_type=config.model.attention_type, + embedding=config.model.embedding, rope_freq_base=config.model.rope_freq_base, rope_freq_scale=config.model.rope_freq_scale, yarn_ext_factor=config.model.yarn_ext_factor, diff --git a/tests/test_server_example_embeddings.py b/tests/test_server_example_embeddings.py new file mode 100644 index 0000000000..e3582d7d0a --- /dev/null +++ b/tests/test_server_example_embeddings.py @@ -0,0 +1,110 @@ +import base64 +import importlib.util +import sys +from pathlib import Path + +import numpy as np +import pytest +from pydantic import ValidationError + + +def load_server_module(): + path = Path(__file__).resolve().parents[1] / "examples" / "server" / "server.py" + spec = importlib.util.spec_from_file_location("example_server", path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +server = load_server_module() + + +def test_embedding_request_accepts_openai_input_shapes(): + assert server.CreateEmbeddingRequest( + input="hello", model="model" + ).normalized_input() == ["hello"] + assert server.CreateEmbeddingRequest( + input=["hello", "world"], + model="model", + ).normalized_input() == ["hello", "world"] + assert server.CreateEmbeddingRequest( + input=[1, 2, 3], model="model" + ).normalized_input() == [[1, 2, 3]] + assert server.CreateEmbeddingRequest( + input=[[1, 2], [3, 4]], + model="model", + ).normalized_input() == [[1, 2], [3, 4]] + + +@pytest.mark.parametrize( + "input_value", + [ + "", + [], + ["valid", ""], + [1, "mixed"], + [[]], + ["x"] * 2049, + [1] * 2049, + [[1] * 2049], + ], +) +def test_embedding_request_rejects_invalid_inputs(input_value): + with pytest.raises(ValidationError): + server.CreateEmbeddingRequest(input=input_value, model="model") + + +def test_embedding_response_supports_dimensions_for_float_output(): + response = server.CreateEmbeddingResponse.from_embeddings( + model="model", + embeddings=[[1.0, 2.0, 3.0]], + total_tokens=3, + encoding_format="float", + dimensions=2, + ) + + assert response.model_dump(mode="json") == { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [1.0, 2.0], + "index": 0, + } + ], + "model": "model", + "usage": { + "prompt_tokens": 3, + "total_tokens": 3, + }, + } + + +def test_embedding_response_supports_base64_output(): + response = server.CreateEmbeddingResponse.from_embeddings( + model="model", + embeddings=[[1.0, 2.0, 3.0]], + total_tokens=3, + encoding_format="base64", + dimensions=2, + ) + + payload = response.model_dump(mode="json") + encoded = payload["data"][0]["embedding"] + assert isinstance(encoded, str) + decoded = np.frombuffer(base64.b64decode(encoded), dtype=np.float32) + np.testing.assert_allclose(decoded, np.array([1.0, 2.0], dtype=np.float32)) + + +def test_embedding_response_rejects_oversized_dimensions(): + with pytest.raises(server.CompletionRequestValidationError): + server.CreateEmbeddingResponse.from_embeddings( + model="model", + embeddings=[[1.0, 2.0, 3.0]], + total_tokens=3, + encoding_format="float", + dimensions=4, + ) From 05315a06d5fd4e36874274e939c9822ce9fbf014 Mon Sep 17 00:00:00 2001 From: abetlen Date: Sun, 7 Jun 2026 06:16:23 -0700 Subject: [PATCH 2/4] docs: update changelog for embeddings endpoint --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 229921847f..4912e00e45 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- feat(example): add OpenAI-compatible embeddings endpoint by @abetlen in #2281 + ## [0.3.27] - feat: update llama.cpp to ggml-org/llama.cpp@465b1f0e7 From bd0837ec1373fba602dc20ab397e61e8e6504c2c Mon Sep 17 00:00:00 2001 From: abetlen Date: Sun, 7 Jun 2026 06:30:44 -0700 Subject: [PATCH 3/4] test: remove server embeddings example test --- tests/test_server_example_embeddings.py | 110 ------------------------ 1 file changed, 110 deletions(-) delete mode 100644 tests/test_server_example_embeddings.py diff --git a/tests/test_server_example_embeddings.py b/tests/test_server_example_embeddings.py deleted file mode 100644 index e3582d7d0a..0000000000 --- a/tests/test_server_example_embeddings.py +++ /dev/null @@ -1,110 +0,0 @@ -import base64 -import importlib.util -import sys -from pathlib import Path - -import numpy as np -import pytest -from pydantic import ValidationError - - -def load_server_module(): - path = Path(__file__).resolve().parents[1] / "examples" / "server" / "server.py" - spec = importlib.util.spec_from_file_location("example_server", path) - assert spec is not None - assert spec.loader is not None - module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module - spec.loader.exec_module(module) - return module - - -server = load_server_module() - - -def test_embedding_request_accepts_openai_input_shapes(): - assert server.CreateEmbeddingRequest( - input="hello", model="model" - ).normalized_input() == ["hello"] - assert server.CreateEmbeddingRequest( - input=["hello", "world"], - model="model", - ).normalized_input() == ["hello", "world"] - assert server.CreateEmbeddingRequest( - input=[1, 2, 3], model="model" - ).normalized_input() == [[1, 2, 3]] - assert server.CreateEmbeddingRequest( - input=[[1, 2], [3, 4]], - model="model", - ).normalized_input() == [[1, 2], [3, 4]] - - -@pytest.mark.parametrize( - "input_value", - [ - "", - [], - ["valid", ""], - [1, "mixed"], - [[]], - ["x"] * 2049, - [1] * 2049, - [[1] * 2049], - ], -) -def test_embedding_request_rejects_invalid_inputs(input_value): - with pytest.raises(ValidationError): - server.CreateEmbeddingRequest(input=input_value, model="model") - - -def test_embedding_response_supports_dimensions_for_float_output(): - response = server.CreateEmbeddingResponse.from_embeddings( - model="model", - embeddings=[[1.0, 2.0, 3.0]], - total_tokens=3, - encoding_format="float", - dimensions=2, - ) - - assert response.model_dump(mode="json") == { - "object": "list", - "data": [ - { - "object": "embedding", - "embedding": [1.0, 2.0], - "index": 0, - } - ], - "model": "model", - "usage": { - "prompt_tokens": 3, - "total_tokens": 3, - }, - } - - -def test_embedding_response_supports_base64_output(): - response = server.CreateEmbeddingResponse.from_embeddings( - model="model", - embeddings=[[1.0, 2.0, 3.0]], - total_tokens=3, - encoding_format="base64", - dimensions=2, - ) - - payload = response.model_dump(mode="json") - encoded = payload["data"][0]["embedding"] - assert isinstance(encoded, str) - decoded = np.frombuffer(base64.b64decode(encoded), dtype=np.float32) - np.testing.assert_allclose(decoded, np.array([1.0, 2.0], dtype=np.float32)) - - -def test_embedding_response_rejects_oversized_dimensions(): - with pytest.raises(server.CompletionRequestValidationError): - server.CreateEmbeddingResponse.from_embeddings( - model="model", - embeddings=[[1.0, 2.0, 3.0]], - total_tokens=3, - encoding_format="float", - dimensions=4, - ) From b9913a2865c010de10d1852f942103f1729666bd Mon Sep 17 00:00:00 2001 From: abetlen Date: Sun, 7 Jun 2026 14:53:44 -0700 Subject: [PATCH 4/4] feat(example): auto-detect embedding model mode --- examples/server/README.md | 4 +- .../server/configs/bge-small-en-v1.5.json | 1 - examples/server/server.py | 75 +++++++++++++++++-- 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index 6e32444288..6a1cc44480 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -213,8 +213,8 @@ Most model runtime fields map to `llama_model_params` or `llama_context_params` | `threads` | Decode thread count. | | `threads_batch` | Prefill and batch thread count. | | `kv_unified` | Selects unified or per-sequence memory layout. | -| `embedding` | Enables llama.cpp embedding extraction for `/v1/embeddings`. | -| `pooling_type` | Selects pooled embedding behavior for embedding models, such as `1` for mean pooling. | +| `embedding` | Overrides embedding mode; omit to auto-detect pooled embedding GGUFs from model metadata. | +| `pooling_type` | Overrides pooled embedding behavior for embedding models, such as `1` for mean pooling. | | `store_logits` | Keeps logits after decode when needed by sampling or diagnostics. | | `use_mmap` | Memory maps model weights. | | `use_mlock` | Attempts to lock model pages into RAM. | diff --git a/examples/server/configs/bge-small-en-v1.5.json b/examples/server/configs/bge-small-en-v1.5.json index cf8bd0c566..3fc8016df8 100644 --- a/examples/server/configs/bge-small-en-v1.5.json +++ b/examples/server/configs/bge-small-en-v1.5.json @@ -9,7 +9,6 @@ "repo_id": "CompendiumLabs/bge-small-en-v1.5-gguf", "filename": "bge-small-en-v1.5-q4_k_m.gguf" }, - "embedding": true, "n_ctx": 512, "n_seq_max": 16, "n_batch": 512, diff --git a/examples/server/server.py b/examples/server/server.py index cbd92fe6fb..2530cf4bec 100644 --- a/examples/server/server.py +++ b/examples/server/server.py @@ -3335,7 +3335,7 @@ class ModelOptions(BaseModel): rope_scaling_type: Optional[int] = None pooling_type: Optional[int] = None attention_type: Optional[int] = None - embedding: bool = False + embedding: Optional[bool] = None rope_freq_base: Optional[float] = None rope_freq_scale: Optional[float] = None yarn_ext_factor: Optional[float] = None @@ -10530,7 +10530,7 @@ def __init__( rope_scaling_type: Optional[int] = None, pooling_type: Optional[int] = None, attention_type: Optional[int] = None, - embedding: bool = False, + embedding: Optional[bool] = None, rope_freq_base: Optional[float] = None, rope_freq_scale: Optional[float] = None, yarn_ext_factor: Optional[float] = None, @@ -10566,7 +10566,6 @@ def __init__( self.chat_template_override = chat_template self.response_schema = response_schema self.store_logits = store_logits - self.embedding = embedding self.max_output_tokens = max_output_tokens self.draft_model_max_batch_size = draft_model_max_batch_size self.draft_provider: Optional[DraftProvider] = None @@ -10597,6 +10596,8 @@ def __init__( if vocab is None: raise RuntimeError("failed to access model vocabulary") self.vocab = vocab + embedding = self.resolve_embedding_mode(llama_model, embedding) + self.embedding = embedding self.has_encoder = bool(llama_cpp.llama_model_has_encoder(llama_model)) self.has_decoder = bool(llama_cpp.llama_model_has_decoder(llama_model)) if self.has_encoder and not embedding: @@ -11074,14 +11075,34 @@ def close(self) -> None: llama_cpp.llama_backend_free() self.backend_initialized = False - def _meta_value(self, key: str) -> Optional[str]: + @staticmethod + def _model_meta_key_by_index(llama_model: Any, index: int) -> Optional[str]: + capacity = 256 + while True: + buffer = ctypes.create_string_buffer(capacity) + count = int( + llama_cpp.llama_model_meta_key_by_index( + llama_model, + index, + cast(Any, buffer), + capacity, + ) + ) + if count < 0: + return None + if count < capacity: + return buffer.value.decode("utf-8", errors="ignore") + capacity = count + 1 + + @staticmethod + def _model_meta_value(llama_model: Any, key: str) -> Optional[str]: encoded = key.encode("utf-8") capacity = 256 while True: buffer = ctypes.create_string_buffer(capacity) count = int( llama_cpp.llama_model_meta_val_str( - self.llama_model, + llama_model, encoded, cast(Any, buffer), capacity, @@ -11093,6 +11114,50 @@ def _meta_value(self, key: str) -> Optional[str]: return buffer.value.decode("utf-8", errors="ignore") capacity = count + 1 + @staticmethod + def _parse_pooling_type(value: str) -> Optional[int]: + normalized = value.strip().lower() + try: + return int(normalized) + except ValueError: + return { + "none": llama_cpp.LLAMA_POOLING_TYPE_NONE, + "mean": llama_cpp.LLAMA_POOLING_TYPE_MEAN, + "cls": llama_cpp.LLAMA_POOLING_TYPE_CLS, + "last": llama_cpp.LLAMA_POOLING_TYPE_LAST, + "rank": llama_cpp.LLAMA_POOLING_TYPE_RANK, + }.get(normalized) + + @classmethod + def detect_embedding_model(cls, llama_model: Any) -> bool: + for index in range(int(llama_cpp.llama_model_meta_count(llama_model))): + key = cls._model_meta_key_by_index(llama_model, index) + if key is None or not key.endswith(".pooling_type"): + continue + value = cls._model_meta_value(llama_model, key) + if value is None: + continue + pooling_type = cls._parse_pooling_type(value) + return pooling_type in { + llama_cpp.LLAMA_POOLING_TYPE_MEAN, + llama_cpp.LLAMA_POOLING_TYPE_CLS, + llama_cpp.LLAMA_POOLING_TYPE_LAST, + } + return False + + @classmethod + def resolve_embedding_mode( + cls, + llama_model: Any, + embedding: Optional[bool], + ) -> bool: + if embedding is not None: + return embedding + return cls.detect_embedding_model(llama_model) + + def _meta_value(self, key: str) -> Optional[str]: + return self._model_meta_value(self.llama_model, key) + def _build_chat_formatter(self) -> Optional[Jinja2ChatFormatter]: template_text = self.chat_template_override if template_text is None: