Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ dependencies = [
"python-dotenv>=1.2.2",
# Used for token estimation before LLM calls (LCORE-1569 / conversation compaction)
"tiktoken>=0.8.0",
# Pydantic AI agent framework, used by the pydantic_ai_lightspeed package (Llama Stack provider)
"pydantic-ai>=1.99.0"
]


Expand Down
1 change: 1 addition & 0 deletions src/pydantic_ai_lightspeed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Pydantic AI integrations/extensions for Lightspeed Core Stack."""
5 changes: 5 additions & 0 deletions src/pydantic_ai_lightspeed/llamastack/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Pydantic AI provider for Llama Stack."""

from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider

# Public API of this package: only the provider class is re-exported.
__all__ = ["LlamaStackProvider"]
123 changes: 123 additions & 0 deletions src/pydantic_ai_lightspeed/llamastack/_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Llama Stack provider implementation for Pydantic AI."""

from __future__ import annotations as _annotations

from typing import TYPE_CHECKING

import httpx
from openai import AsyncOpenAI
from pydantic_ai import ModelProfile
from pydantic_ai.models import create_async_http_client
from pydantic_ai.profiles.openai import openai_model_profile
from pydantic_ai.providers import Provider

from pydantic_ai_lightspeed.llamastack._transport import LlamaStackLibraryTransport

if TYPE_CHECKING:
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient

# Base URL used in server mode when the caller does not pass ``base_url``.
DEFAULT_BASE_URL = "http://localhost:8321/v1"


class LlamaStackProvider(Provider[AsyncOpenAI]):
    """Provider for Llama Stack — connects to a Llama Stack server's OpenAI-compatible API.

    Supports two modes:

    1. **Server mode** — connect to a running Llama Stack server via HTTP
    2. **Library mode** — run Llama Stack in-process via ``AsyncLlamaStackAsLibraryClient``

    In both modes the provider exposes an ``AsyncOpenAI`` client; in library
    mode that client is backed by an ``httpx`` transport that dispatches
    requests to the library client instead of the network.
    """

    @property
    def name(self) -> str:
        """The provider name."""
        return "llama-stack"

    @property
    def base_url(self) -> str:
        """The base URL for the provider API."""
        return str(self._client.base_url)

    @property
    def client(self) -> AsyncOpenAI:
        """The OpenAI-compatible client for the provider."""
        return self._client

    @staticmethod
    def model_profile(model_name: str) -> ModelProfile | None:
        """Return the model profile for the named model, if available."""
        return openai_model_profile(model_name)

    def __init__(
        self,
        *,
        base_url: str | None = None,
        api_key: str | None = None,
        library_client: AsyncLlamaStackAsLibraryClient | None = None,
        http_client: httpx.AsyncClient | None = None,
    ) -> None:
        """Create a new Llama Stack provider.

        Args:
            base_url: The base URL for the Llama Stack server (OpenAI-compatible endpoint).
                Defaults to ``http://localhost:8321/v1``.
                Must be ``None`` when ``library_client`` is provided.
            api_key: The API key for authentication. Defaults to ``'not-needed'`` since
                local Llama Stack servers typically don't require one.
                Must be ``None`` when ``library_client`` is provided.
            library_client: An initialized ``AsyncLlamaStackAsLibraryClient`` for library mode.
                When provided, requests are dispatched in-process (no server needed).
                Mutually exclusive with ``base_url``, ``api_key``, and ``http_client``.
            http_client: An existing ``httpx.AsyncClient`` to use for making HTTP requests.
                Must be ``None`` when ``library_client`` is provided.

        Raises:
            ValueError: If ``library_client`` is combined with ``base_url``,
                ``api_key`` or ``http_client``.
        """
        # Always define the attribute so that it can be inspected in either
        # mode; it stays ``None`` for server-mode providers.
        self._library_client = library_client

        if library_client is None:
            # Server mode: talk HTTP to a Llama Stack server, reusing the
            # caller's httpx client when one was supplied.
            self._client = AsyncOpenAI(
                base_url=base_url or DEFAULT_BASE_URL,
                api_key=api_key or "not-needed",
                http_client=(
                    http_client
                    if http_client is not None
                    else create_async_http_client()
                ),
            )
            return

        # Library mode: every server-mode option must be left unset.
        server_only_options = (
            ("base_url", base_url),
            ("api_key", api_key),
            ("http_client", http_client),
        )
        for option_name, option_value in server_only_options:
            if option_value is not None:
                raise ValueError(
                    f"Cannot provide both `library_client` and `{option_name}`"
                )

        # The host name below is a placeholder: the custom transport hands
        # each request to the library client instead of opening a connection.
        # Timeouts are disabled explicitly; requests are handled in-process by
        # the transport, not sent over a network.
        lib_http_client = httpx.AsyncClient(
            transport=LlamaStackLibraryTransport(library_client),
            base_url="http://llama-stack-library",
            timeout=httpx.Timeout(None),
        )
        self._client = AsyncOpenAI(
            http_client=lib_http_client,
            base_url="http://llama-stack-library/v1",
            api_key="not-needed",
        )

    def __repr__(self) -> str:
        """Return a string representation of the provider."""
        return f"LlamaStackProvider(name={self.name!r}, base_url={self.base_url!r})"

    def _set_http_client(self, http_client: httpx.AsyncClient) -> None:
        """Inject an httpx.AsyncClient into the underlying OpenAI client.

        Replaces the internal HTTP transport by assigning directly to the
        protected ``self._client._client`` attribute of the AsyncOpenAI instance.

        Args:
            http_client: The async HTTP client to use for subsequent requests.
        """
        self._client._client = http_client  # pyright: ignore[reportPrivateUsage] # pylint: disable=protected-access
Comment thread
coderabbitai[bot] marked this conversation as resolved.
197 changes: 197 additions & 0 deletions src/pydantic_ai_lightspeed/llamastack/_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""httpx transport that routes OpenAI-compatible requests through a Llama Stack library client."""

from __future__ import annotations as _annotations

import json
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any

import httpx
from llama_stack.core.library_client import (
AsyncLlamaStackAsLibraryClient,
convert_pydantic_to_json_value,
)
from llama_stack.core.request_headers import (
PROVIDER_DATA_VAR,
request_provider_data_context,
)
from llama_stack.core.server.routes import find_matching_route
from llama_stack.core.utils.context import preserve_contexts_async_generator


class _AsyncByteStream(httpx.AsyncByteStream):
    """Wraps an async byte generator as an httpx AsyncByteStream.

    Iteration and closing are both forwarded to the wrapped generator, so the
    generator is finalized when the response that owns this stream is closed.
    """

    def __init__(self, gen: AsyncGenerator[bytes, None]) -> None:
        """Store an async generator that yields raw bytes for streaming.

        Args:
            gen: An async generator producing byte chunks to stream.
        """
        self._gen = gen

    async def __aiter__(self) -> AsyncIterator[bytes]:
        """Yield bytes chunks from the wrapped generator.

        Returns:
            An async iterator of bytes fulfilling the httpx.AsyncByteStream contract.
        """
        async for chunk in self._gen:
            yield chunk

    async def aclose(self) -> None:
        """Close the wrapped generator.

        The base-class ``aclose()`` does nothing, which would leave a
        partially consumed generator (and whatever it is holding open)
        suspended until garbage collection when a response is closed before
        its body has been fully read. Closing an already exhausted generator
        is a no-op.
        """
        await self._gen.aclose()
Comment thread
coderabbitai[bot] marked this conversation as resolved.


class LlamaStackLibraryTransport(httpx.AsyncBaseTransport):
    """Custom httpx transport that dispatches requests through a Llama Stack library client.

    Instead of making real HTTP calls, this transport routes requests directly
    to the Llama Stack's in-process route handlers via the library client's
    route matching and body conversion logic.
    """

    # Canonical spelling of the header that carries per-request provider data.
    _PROVIDER_DATA_HEADER = "X-LlamaStack-Provider-Data"

    def __init__(self, client: AsyncLlamaStackAsLibraryClient) -> None:
        """Initialize the transport with a Llama Stack library client.

        Args:
            client: An initialized ``AsyncLlamaStackAsLibraryClient`` whose route
                handlers will receive dispatched requests.
        """
        self._client = client

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Dispatch an httpx request to the in-process Llama Stack route handlers.

        Args:
            request: The outgoing httpx request to route.

        Returns:
            An httpx response built from the matched route handler result.

        Raises:
            RuntimeError: If the library client has not been initialized.
            ValueError: If the request body is JSON but not a JSON object.
        """
        if self._client.route_impls is None:
            raise RuntimeError(
                "Llama Stack library client not initialized. Call initialize() first."
            )

        method = request.method
        # ``raw_path`` is the request target, i.e. path *and* query string.
        # Only the path takes part in route matching, so cut the query off.
        raw_path, _, _ = request.url.raw_path.partition(b"?")
        path = raw_path.decode("utf-8")

        json_body = json.loads(request.content) if request.content else {}
        if not isinstance(json_body, dict):
            raise ValueError(
                "Llama Stack library transport expects a JSON object request body"
            )

        # Streaming is requested through the JSON body only; query values are
        # always strings and must not be mistaken for a truthy flag.
        is_stream = json_body.get("stream", False)

        # Query parameters become handler arguments too; on a name clash the
        # JSON body wins.
        body: dict[str, Any] = {**dict(request.url.params), **json_body}

        headers: dict[str, str] = {
            k.decode("utf-8") if isinstance(k, bytes) else k: (
                v.decode("utf-8") if isinstance(v, bytes) else v
            )
            for k, v in request.headers.raw
        }

        # Header names are case-insensitive, so look the header up through
        # ``request.headers`` and store it under the canonical spelling. A
        # value supplied by the caller takes precedence over the provider
        # data configured on the library client.
        caller_provider_data = request.headers.get(self._PROVIDER_DATA_HEADER)
        if caller_provider_data is not None:
            headers[self._PROVIDER_DATA_HEADER] = caller_provider_data
        elif self._client.provider_data:
            headers[self._PROVIDER_DATA_HEADER] = json.dumps(
                self._client.provider_data
            )

        with request_provider_data_context(headers):
            if is_stream:
                return await self._handle_streaming(request, method, path, body)
            return await self._handle_non_streaming(request, method, path, body)

    async def _invoke_route(
        self,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> Any:
        """Find the route handler for ``method``/``path`` and await it.

        Args:
            method: The HTTP method (e.g. ``"POST"``).
            path: The URL path used for route matching.
            body: Request arguments; path parameters extracted from ``path``
                override entries with the same name.

        Returns:
            Whatever the matched route handler returns.

        Raises:
            RuntimeError: If route_impls is not initialized.
        """
        if self._client.route_impls is None:
            raise RuntimeError("route_impls is not initialized")

        matched_func, path_params, _, _ = find_matching_route(
            method, path, self._client.route_impls
        )
        call_kwargs = self._client._convert_body(  # pylint: disable=protected-access
            matched_func, {**body, **path_params}
        )
        return await matched_func(**call_kwargs)

    async def _handle_non_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Dispatch a non-streaming request to the matched route handler.

        Args:
            request: The original httpx request (attached to the response).
            method: The HTTP method (e.g. ``"POST"``).
            path: The decoded URL path used for route matching.
            body: The parsed JSON request body.

        Returns:
            An httpx.Response containing the JSON-serialized handler result,
            or an empty 204 response for a ``DELETE`` whose handler returned
            ``None``.

        Raises:
            RuntimeError: If route_impls is not initialized.
        """
        result = await self._invoke_route(method, path, body)

        if method.upper() == "DELETE" and result is None:
            status_code = httpx.codes.NO_CONTENT
            json_content = ""
        else:
            status_code = httpx.codes.OK
            json_content = json.dumps(convert_pydantic_to_json_value(result))

        return httpx.Response(
            status_code=status_code,
            content=json_content.encode("utf-8"),
            headers={"Content-Type": "application/json"},
            request=request,
        )

    async def _handle_streaming(
        self,
        request: httpx.Request,
        method: str,
        path: str,
        body: dict[str, Any],
    ) -> httpx.Response:
        """Dispatch a streaming request and return an SSE event-stream response.

        Args:
            request: The original httpx request (attached to the response).
            method: The HTTP method (e.g. ``"POST"``).
            path: The decoded URL path used for route matching.
            body: The parsed JSON request body (must contain ``stream: True``).

        Returns:
            An httpx.Response with a streaming body of SSE-formatted chunks.

        Raises:
            RuntimeError: If route_impls is not initialized.
        """
        result = await self._invoke_route(method, path, body)

        async def sse_events() -> AsyncGenerator[bytes, None]:
            # One ``data: <json>`` event per chunk produced by the handler.
            async for chunk in result:
                data = json.dumps(convert_pydantic_to_json_value(chunk))
                yield f"data: {data}\n\n".encode("utf-8")

        # NOTE(review): the body is read after ``handle_async_request`` has
        # returned and left its provider-data context; the wrapper presumably
        # keeps PROVIDER_DATA_VAR visible to the generator — confirm.
        wrapped_gen = preserve_contexts_async_generator(
            sse_events(), [PROVIDER_DATA_VAR]
        )

        return httpx.Response(
            status_code=httpx.codes.OK,
            stream=_AsyncByteStream(wrapped_gen),
            headers={"Content-Type": "text/event-stream"},
            request=request,
        )
1 change: 1 addition & 0 deletions tests/unit/pydantic_ai_lightspeed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for the pydantic_ai_lightspeed package."""
1 change: 1 addition & 0 deletions tests/unit/pydantic_ai_lightspeed/llamastack/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Unit tests for pydantic_ai_lightspeed.llamastack sub-package."""
Loading
Loading