diff --git a/samples/README.md b/samples/README.md index e61264955..5034c1baf 100644 --- a/samples/README.md +++ b/samples/README.md @@ -8,8 +8,9 @@ This directory contains runnable examples demonstrating how to build and interac |---|---|---| | `hello_world_agent.py` | **Server** | A2A agent server | | `cli.py` | **Client** | Interactive terminal client | +| `text_client_cli.py` | **Client** | Simplified text-only interactive terminal client | -The samples are designed to work together out of the box: the agent listens on `http://127.0.0.1:41241`, which is the default URL used by the client. +All three samples are designed to work together out of the box: the agent listens on `http://127.0.0.1:41241`, which is the default URL used by both clients. --- ## `hello_world_agent.py` — Agent Server @@ -53,6 +54,44 @@ uv run python samples/cli.py uv run python samples/cli.py --url http://192.168.1.10:41241 --transport GRPC ``` -Then type a message like `hello` and press Enter. +Type `/quit` or `/exit` to stop, or press `Ctrl+C`. + +--- + +## `text_client_cli.py` — Simple Text Client + +A stripped-down interactive client using the high-level `TextClient` abstraction. It hides all streaming and event mechanics, presenting a simple request/response interface. + +Ideal for understanding the **minimum code required** to call an A2A agent. + +**Run:** + +```bash +# Connect to the local hello_world_agent (default): +uv run python samples/text_client_cli.py + +# Connect to a different URL: +uv run python samples/text_client_cli.py --url http://192.168.1.10:41241 + +# Use a specific transport: +uv run python samples/text_client_cli.py --transport GRPC +``` Type `/quit` or `/exit` to stop, or press `Ctrl+C`. + +--- + + +## Quick Start + +In two separate terminals: + +```bash +# Terminal 1 — start the agent +uv run python samples/hello_world_agent.py + +# Terminal 2 — start the client +uv run python samples/cli.py +``` + +Then type a message like `hello` and press Enter. diff --git a/samples/cli.py b/samples/cli.py index 54b68388f..416d4c31b 100644 --- a/samples/cli.py +++ b/samples/cli.py @@ -76,6 +76,8 @@ async def main() -> None: config = ClientConfig() if args.transport: config.supported_protocol_bindings = [args.transport] + if args.transport == 'GRPC': + config.grpc_channel_factory = grpc.aio.insecure_channel print( f'Connecting to {args.url} (preferred transport: {args.transport or "Any"})' diff --git a/samples/text_client_cli.py b/samples/text_client_cli.py new file mode 100644 index 000000000..67fd33ea8 --- /dev/null +++ b/samples/text_client_cli.py @@ -0,0 +1,70 @@ +import argparse +import asyncio + +import grpc +import httpx + +from a2a.client import A2ACardResolver, ClientConfig, create_text_client + + +async def main() -> None: + """Run the simple A2A terminal client using TextClient.""" + parser = argparse.ArgumentParser(description='A2A Simple Text Client') + parser.add_argument( + '--url', default='http://127.0.0.1:41241', help='Agent base URL' + ) + parser.add_argument( + '--transport', + default=None, + help='Preferred transport (JSONRPC, HTTP+JSON, GRPC)', + ) + args = parser.parse_args() + + config = ClientConfig() + if args.transport: + config.supported_protocol_bindings = [args.transport] + if args.transport == 'GRPC': + config.grpc_channel_factory = grpc.aio.insecure_channel + + print( + f'Connecting to {args.url} (preferred transport: {args.transport or "Any"})' + ) + + async with httpx.AsyncClient() as httpx_client: + resolver = A2ACardResolver(httpx_client, args.url) + card = await resolver.get_agent_card() + print('\n✓ Agent Card Found:') + print(f' Name: {card.name}') + + text_client = await create_text_client(card, client_config=config) + + actual_transport = getattr( + text_client.client, '_transport', text_client.client + ) + print(f' Picked Transport: {actual_transport.__class__.__name__}') + + print('\nConnected! Send a message or type /quit to exit.') + + while True: + try: + loop = asyncio.get_running_loop() + user_input = await loop.run_in_executor(None, input, 'You: ') + except KeyboardInterrupt: + break + + if user_input.lower() in ('/quit', '/exit'): + break + if not user_input.strip(): + continue + + try: + response = await text_client.send_text_message(user_input) + print(f'Agent: {response}') + except (httpx.RequestError, grpc.RpcError) as e: + print(f'Error communicating with agent: {e}') + + await text_client.close() + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index c23041f32..4a6d05a27 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -15,6 +15,7 @@ from a2a.client.client_factory import ( ClientFactory, create_client, + create_text_client, minimal_agent_card, ) from a2a.client.errors import ( @@ -24,6 +25,7 @@ ) from a2a.client.helpers import create_text_message_object from a2a.client.interceptors import ClientCallInterceptor +from a2a.client.text_client import TextClient __all__ = [ @@ -40,7 +42,9 @@ 'ClientFactory', 'CredentialService', 'InMemoryContextCredentialStore', + 'TextClient', 'create_client', + 'create_text_client', 'create_text_message_object', 'minimal_agent_card', ] diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index a59189ade..320b415c3 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -12,6 +12,7 @@ from a2a.client.base_client import BaseClient from a2a.client.card_resolver import A2ACardResolver from a2a.client.client import Client, ClientConfig +from a2a.client.text_client import TextClient from a2a.client.transports.base import ClientTransport from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.client.transports.rest import RestTransport @@ -406,6 +407,47 @@ async def create_client( # noqa: PLR0913 return factory.create(agent, interceptors) +async def create_text_client( # noqa: PLR0913 + agent: str | AgentCard, + client_config: ClientConfig | None = None, + interceptors: list[ClientCallInterceptor] | None = None, + relative_card_path: str | None = None, + resolver_http_kwargs: dict[str, Any] | None = None, + signature_verifier: Callable[[AgentCard], None] | None = None, +) -> TextClient: + """Create a `TextClient` for an agent from a URL or `AgentCard`. + + Convenience function that constructs a `ClientFactory` internally. + For reusing a factory across multiple agents or registering custom + transports, use `ClientFactory` directly instead. + + Args: + agent: The base URL of the agent, or an `AgentCard` to use + directly. + client_config: Optional `ClientConfig`. A default config is + created if not provided. + interceptors: A list of interceptors to use for each request. + relative_card_path: The relative path when resolving the agent + card. Only used when `agent` is a URL. + resolver_http_kwargs: Dictionary of arguments to provide to the + httpx client when resolving the agent card. + signature_verifier: A callable used to verify the agent card's + signatures. + + Returns: + A `TextClient` wrapping the constructed `Client`. + """ + client = await create_client( + agent=agent, + client_config=client_config, + interceptors=interceptors, + relative_card_path=relative_card_path, + resolver_http_kwargs=resolver_http_kwargs, + signature_verifier=signature_verifier, + ) + return TextClient(client) + + def minimal_agent_card( url: str, transports: list[str] | None = None ) -> AgentCard: diff --git a/src/a2a/client/text_client.py b/src/a2a/client/text_client.py new file mode 100644 index 000000000..90e5040aa --- /dev/null +++ b/src/a2a/client/text_client.py @@ -0,0 +1,114 @@ +import uuid + +from types import TracebackType + +from typing_extensions import Self + +from a2a.client.client import Client, ClientCallContext +from a2a.types import Message, Part, Role, SendMessageRequest, TaskState +from a2a.utils import get_artifact_text, get_message_text + + +_TERMINAL_STATES: frozenset[TaskState] = frozenset( + { + TaskState.TASK_STATE_COMPLETED, + TaskState.TASK_STATE_FAILED, + TaskState.TASK_STATE_CANCELED, + TaskState.TASK_STATE_REJECTED, + } +) + + +class TextClient: + """A facade around Client that simplifies text-based communication. + + Wraps an underlying Client instance and exposes a simplified interface + for sending plain-text messages and receiving aggregated text responses. + Maintains session state (context_id, task_id) automatically across calls. + For full Client API access, use the underlying client directly via + the `client` property. + """ + + def __init__(self, client: Client): + self._client = client + self._context_id: str = str(uuid.uuid4()) + self._task_id: str | None = None + + async def __aenter__(self) -> Self: + """Enters the async context manager.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exits the async context manager and closes the client.""" + await self.close() + + @property + def client(self) -> Client: + """Returns the underlying Client instance for full API access.""" + return self._client + + def reset_session(self) -> None: + """Starts a new session by generating a fresh context ID and clearing the task ID.""" + self._context_id = str(uuid.uuid4()) + self._task_id = None + + async def send_text_message( + self, + text: str, + *, + delimiter: str = ' ', + context: ClientCallContext | None = None, + ) -> str: + """Sends a text message and returns the aggregated text response. + + Session state (context_id, task_id) is managed automatically across + calls. Use reset_session() to start a new conversation. + + Args: + text: The plain-text message to send. + delimiter: String used to join response parts. Defaults to a + single space. Use '' for token-streamed responses or a + newline for paragraph-separated chunks. + context: Optional call-level context. + """ + request = SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id=str(uuid.uuid4()), + context_id=self._context_id, + task_id=self._task_id, + parts=[Part(text=text)], + ) + ) + + response_parts: list[str] = [] + + async for event in self._client.send_message(request, context=context): + if event.HasField('task'): + self._task_id = event.task.id + elif event.HasField('message'): + response_parts.append(get_message_text(event.message)) + elif event.HasField('status_update'): + if not self._task_id and event.status_update.task_id: + self._task_id = event.status_update.task_id + if event.status_update.status.state in _TERMINAL_STATES: + self._task_id = None + if event.status_update.status.HasField('message'): + response_parts.append( + get_message_text(event.status_update.status.message) + ) + elif event.HasField('artifact_update'): + response_parts.append( + get_artifact_text(event.artifact_update.artifact) + ) + + return delimiter.join(response_parts) + + async def close(self) -> None: + """Closes the underlying client.""" + await self._client.close() diff --git a/tests/client/test_text_client.py b/tests/client/test_text_client.py new file mode 100644 index 000000000..380033a3c --- /dev/null +++ b/tests/client/test_text_client.py @@ -0,0 +1,227 @@ +from unittest.mock import AsyncMock + +import pytest + +from a2a.client import ( + Client, + ClientCallContext, + ClientConfig, + TextClient, + create_text_client, + minimal_agent_card, +) +from a2a.types import Part, StreamResponse, TaskState + + +@pytest.fixture +def mock_client() -> AsyncMock: + return AsyncMock(spec=Client) + + +@pytest.fixture +def text_client(mock_client: AsyncMock) -> TextClient: + return TextClient(mock_client) + + +def test_client_property( + text_client: TextClient, mock_client: AsyncMock +) -> None: + assert text_client.client is mock_client + + +@pytest.mark.asyncio +async def test_create_client_and_wrap() -> None: + # Create a minimal card + card = minimal_agent_card(url='http://test.com', transports=['JSONRPC']) + + config = ClientConfig(supported_protocol_bindings=['JSONRPC']) + + text_client = await create_text_client(card, client_config=config) + + assert isinstance(text_client, TextClient) + assert isinstance(text_client.client, Client) + + # Clean up + await text_client.close() + + +@pytest.mark.asyncio +async def test_send_text_message( + text_client: TextClient, mock_client: AsyncMock +) -> None: + async def create_stream(*args, **kwargs): + # Event 0: task (ignored) + resp0 = StreamResponse() + resp0.task.id = 'task-1' + yield resp0 + + # Event 1: status update + resp1 = StreamResponse() + resp1.status_update.status.message.parts.append(Part(text='Hello')) + yield resp1 + + # Event 2: status update without message + resp2 = StreamResponse() + resp2.status_update.status.state = 1 + yield resp2 + + # Event 3: status update with message + resp3 = StreamResponse() + resp3.status_update.status.message.parts.append(Part(text='Processing')) + yield resp3 + + # Event 4: artifact update + resp4 = StreamResponse() + resp4.artifact_update.artifact.parts.append(Part(text='World!')) + yield resp4 + + mock_client.send_message.return_value = create_stream() + + response = await text_client.send_text_message('Hi') + + assert response == 'Hello Processing World!' + mock_client.send_message.assert_called_once() + args, _ = mock_client.send_message.call_args + request = args[0] + assert request.message.parts[0].text == 'Hi' + + +@pytest.mark.asyncio +async def test_send_text_message( + text_client: TextClient, mock_client: AsyncMock +) -> None: + async def create_stream(*args, **kwargs): + resp1 = StreamResponse() + resp1.message.parts.append(Part(text='Hello')) + yield resp1 + + mock_client.send_message.return_value = create_stream() + response = await text_client.send_text_message('Hi') + assert response == 'Hello' + + +@pytest.mark.asyncio +async def test_send_text_message_forwards_context( + text_client: TextClient, mock_client: AsyncMock +) -> None: + async def empty_stream(*args, **kwargs): + return + yield + + mock_client.send_message.return_value = empty_stream() + context = ClientCallContext() + + await text_client.send_text_message('Hi', context=context) + + _, kwargs = mock_client.send_message.call_args + assert kwargs['context'] is context + + +def test_reset_session_changes_context_id(text_client: TextClient) -> None: + # Access internal state only to verify reset behaviour, not as public API + original = text_client._context_id + text_client.reset_session() + assert text_client._context_id != original + assert text_client._task_id is None + + +@pytest.mark.asyncio +async def test_send_text_message_sets_task_id_from_task_event( + text_client: TextClient, mock_client: AsyncMock +) -> None: + async def create_stream(*args, **kwargs): + resp = StreamResponse() + resp.task.id = 'task-123' + yield resp + + mock_client.send_message.return_value = create_stream() + await text_client.send_text_message('Hi') + assert text_client._task_id == 'task-123' + + +@pytest.mark.asyncio +async def test_send_text_message_sets_task_id_from_status_update( + text_client: TextClient, mock_client: AsyncMock +) -> None: + async def create_stream(*args, **kwargs): + resp = StreamResponse() + resp.status_update.task_id = 'task-456' + resp.status_update.status.state = 1 + yield resp + + mock_client.send_message.return_value = create_stream() + await text_client.send_text_message('Hi') + assert text_client._task_id == 'task-456' + + +@pytest.mark.asyncio +async def test_session_ids_passed_in_request( + text_client: TextClient, mock_client: AsyncMock +) -> None: + async def create_stream(*args, **kwargs): + resp = StreamResponse() + resp.task.id = 'task-789' + yield resp + + mock_client.send_message.return_value = create_stream() + context_id = text_client._context_id + + await text_client.send_text_message('Hi') + + args, _ = mock_client.send_message.call_args + request = args[0] + assert request.message.context_id == context_id + assert not request.message.task_id + + # Second call carries the task_id from the first + async def create_stream2(*args, **kwargs): + return + yield + + mock_client.send_message.return_value = create_stream2() + await text_client.send_text_message('Follow up') + + args, _ = mock_client.send_message.call_args + request = args[0] + assert request.message.context_id == context_id + assert request.message.task_id == 'task-789' + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'terminal_state', + [ + TaskState.TASK_STATE_COMPLETED, + TaskState.TASK_STATE_FAILED, + TaskState.TASK_STATE_CANCELED, + TaskState.TASK_STATE_REJECTED, + ], +) +async def test_task_id_cleared_on_terminal_state( + text_client: TextClient, + mock_client: AsyncMock, + terminal_state: TaskState, +) -> None: + async def create_stream(*args, **kwargs): + resp = StreamResponse() + resp.status_update.task_id = 'task-abc' + resp.status_update.status.state = terminal_state + yield resp + + mock_client.send_message.return_value = create_stream() + await text_client.send_text_message('Hi') + assert text_client._task_id is None + + +@pytest.mark.asyncio +async def test_async_context_manager(mock_client: AsyncMock) -> None: + async with TextClient(mock_client) as client: + assert isinstance(client, TextClient) + mock_client.close.assert_not_awaited() + mock_client.close.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_close(text_client: TextClient, mock_client: AsyncMock) -> None: + await text_client.close() + mock_client.close.assert_awaited_once()