diff --git a/.vscode/launch.json b/.vscode/launch.json index 5c19f4812..c3a178184 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -5,37 +5,14 @@ "name": "Debug HelloWorld Agent", "type": "debugpy", "request": "launch", - "program": "${workspaceFolder}/examples/helloworld/__main__.py", + "program": "${workspaceFolder}/samples/hello_world_agent.py", "console": "integratedTerminal", "justMyCode": false, + "python": "${workspaceFolder}/.venv/bin/python3", + "cwd": "${workspaceFolder}", "env": { - "PYTHONPATH": "${workspaceFolder}" - }, - "cwd": "${workspaceFolder}/examples/helloworld", - "args": [ - "--host", - "localhost", - "--port", - "9999" - ] - }, - { - "name": "Debug Currency Agent", - "type": "debugpy", - "request": "launch", - "program": "${workspaceFolder}/examples/langgraph/__main__.py", - "console": "integratedTerminal", - "justMyCode": false, - "env": { - "PYTHONPATH": "${workspaceFolder}" - }, - "cwd": "${workspaceFolder}/examples/langgraph", - "args": [ - "--host", - "localhost", - "--port", - "10000" - ] + "PYTHONPATH": "${workspaceFolder}/src" + } }, { "name": "Pytest All", diff --git a/samples/README.md b/samples/README.md new file mode 100644 index 000000000..f9129c38a --- /dev/null +++ b/samples/README.md @@ -0,0 +1,60 @@ +# A2A Python SDK — Samples + +This directory contains runnable examples demonstrating how to build and interact with an A2A-compliant agent using the Python SDK. + +## Contents + +| File | Role | Description | +|---|---|---| +| `hello_world_agent.py` | **Server** | A2A agent server | +| `cli.py` | **Client** | Interactive terminal client | + +The samples are designed to work together out of the box: the agent listens on `http://127.0.0.1:41241`, which is the default URL used by the client. +--- + +## `hello_world_agent.py` — Agent Server + +Implements an A2A agent that responds to simple greeting messages (e.g., "hello", "how are you", "bye") with text replies, simulating a 1-second processing delay. + +Demonstrates: +- Subclassing `AgentExecutor` and implementing `execute()` / `cancel()` +- Publishing streaming status updates and artifacts via `TaskUpdater` +- Exposing all three transports in both protocol versions (v1.0 and v0.3 compat) simultaneously: + - **JSON-RPC** (v1.0 and v0.3) at `http://127.0.0.1:41241/a2a/jsonrpc` + - **HTTP+JSON (REST)** (v1.0 and v0.3) at `http://127.0.0.1:41241/a2a/rest` + - **gRPC v1.0** on port `50051` + - **gRPC v0.3 (compat)** on port `50052` +- Serving the agent card at `http://127.0.0.1:41241/.well-known/agent-card.json` + +**Run:** + +```bash +uv run python samples/hello_world_agent.py +``` + +--- + +## `cli.py` — Client + +An interactive terminal client with full visibility into the streaming event flow. Each `TaskStatusUpdate` and `TaskArtifactUpdate` event is printed as it arrives. + +Features: +- Transport selection via `--transport` flag (`JSONRPC`, `HTTP+JSON`, `GRPC`) +- Session management (`context_id` persisted across messages, `task_id` per task) +- Graceful error handling for HTTP and gRPC failures + +**Run:** + +```bash +# Connect to the local hello_world_agent (default): +uv run python samples/cli.py + +# Connect to a different URL, using gRPC: +uv run python samples/cli.py --url http://192.168.1.10:41241 --transport GRPC +``` + +Then type a message like `hello` and press Enter. + +Type `/quit` or `/exit` to stop, or press `Ctrl+C`. + + diff --git a/samples/cli.py b/samples/cli.py index 8515fd5a9..46e009516 100644 --- a/samples/cli.py +++ b/samples/cli.py @@ -11,43 +11,52 @@ from a2a.client import A2ACardResolver, ClientConfig, create_client from a2a.types import Message, Part, Role, SendMessageRequest, TaskState +from a2a.utils import get_artifact_text, get_message_text async def _handle_stream( stream: Any, current_task_id: str | None ) -> str | None: - async for event, task in stream: - if not task: + async for event in stream: + if event.HasField('message'): + print(f'Message: {get_message_text(event.message, delimiter=" ")}') continue - if not current_task_id: - current_task_id = task.id - - if event: - if event.HasField('status_update'): - state_name = TaskState.Name(event.status_update.status.state) - print(f'TaskStatusUpdate [state={state_name}]:', end=' ') - if event.status_update.status.HasField('message'): - for part in event.status_update.status.message.parts: - if part.text: - print(part.text, end=' ') - print() - - if ( - event.status_update.status.state - == TaskState.TASK_STATE_COMPLETED - ): - current_task_id = None - print('--- Task Completed ---') + if not current_task_id: + # V2 handler emits Task or Message first. + if event.HasField('task'): + current_task_id = event.task.id + state_name = TaskState.Name(event.task.status.state) + print(f'Task [state={state_name}]') + # Legacy handler might not send a leading Task event. + elif event.HasField('status_update'): + current_task_id = event.status_update.task_id elif event.HasField('artifact_update'): - print( - f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', - end=' ', - ) - for part in event.artifact_update.artifact.parts: - if part.text: - print(part.text, end=' ') + current_task_id = event.artifact_update.artifact.task_id + + if event.HasField('status_update'): + state_name = TaskState.Name(event.status_update.status.state) + print(f'TaskStatusUpdate [state={state_name}]:', end=' ') + if event.status_update.status.HasField('message'): + message = event.status_update.status.message + print(get_message_text(message, delimiter=' ')) + else: print() + if ( + event.status_update.status.state + == TaskState.TASK_STATE_COMPLETED + ): + current_task_id = None + print('--- Task Completed ---') + + elif event.HasField('artifact_update'): + print( + f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', + end=' ', + ) + print( + get_artifact_text(event.artifact_update.artifact, delimiter=' ') + ) return current_task_id @@ -68,6 +77,8 @@ async def main() -> None: config = ClientConfig() if args.transport: config.supported_protocol_bindings = [args.transport] + if args.transport == 'GRPC': + config.grpc_channel_factory = grpc.aio.insecure_channel print( f'Connecting to {args.url} (preferred transport: {args.transport or "Any"})' diff --git a/src/a2a/server/request_handlers/default_request_handler_v2.py b/src/a2a/server/request_handlers/default_request_handler_v2.py index 1a8464687..55e234605 100644 --- a/src/a2a/server/request_handlers/default_request_handler_v2.py +++ b/src/a2a/server/request_handlers/default_request_handler_v2.py @@ -37,6 +37,8 @@ SubscribeToTaskRequest, Task, TaskPushNotificationConfig, + TaskState, + TaskStatus, TaskStatusUpdateEvent, ) from a2a.utils.errors import ( @@ -302,16 +304,35 @@ async def on_message_send_stream( # noqa: D102 params: SendMessageRequest, context: ServerCallContext, ) -> AsyncGenerator[Event, None]: + is_new_task = not params.message.task_id + active_task, request_context = await self._setup_active_task( params, context ) - task_id = cast('str', request_context.task_id) + context_id = cast('str', request_context.context_id) + first_event = True async for event in active_task.subscribe( request=request_context, include_initial_task=False, ): + if ( + first_event + and is_new_task + and not isinstance(event, (Task, Message)) + ): + # Agent didn't emit a Task/Message first. + # The stream MUST begin with a Task or Message. + submitted_task = Task( + id=task_id, + context_id=context_id, + status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), + history=[params.message], + ) + yield apply_history_length(submitted_task, params.configuration) + first_event = False + if isinstance(event, Task): self._validate_task_id_match(task_id, event.id) yield apply_history_length(event, params.configuration) diff --git a/tests/integration/test_scenarios.py b/tests/integration/test_scenarios.py index cee15bfcb..8233f207d 100644 --- a/tests/integration/test_scenarios.py +++ b/tests/integration/test_scenarios.py @@ -261,10 +261,21 @@ async def cancel( event async for event in client.send_message(SendMessageRequest(message=msg)) ] - assert [event.status_update.status.state for event in events] == [ - TaskState.TASK_STATE_WORKING, - TaskState.TASK_STATE_COMPLETED, - ] + if use_legacy: + # Legacy handler streams events as-is (no Task(SUBMITTED) injection). + assert [event.status_update.status.state for event in events] == [ + TaskState.TASK_STATE_WORKING, + TaskState.TASK_STATE_COMPLETED, + ] + else: + # V2 handler injects Task(SUBMITTED) first per A2A spec §3.1.2. + assert events[0].HasField('task'), ( + 'First streaming event must be a Task or Message' + ) + assert [event.status_update.status.state for event in events[1:]] == [ + TaskState.TASK_STATE_WORKING, + TaskState.TASK_STATE_COMPLETED, + ] # Scenario 5: Re-subscribing to a finished task @@ -374,15 +385,32 @@ async def cancel( configuration=SendMessageConfiguration(return_immediately=False), ) ) - (event,) = [event async for event in it] + events = [event async for event in it] if streaming: - assert event.HasField('status_update') - task_id = event.status_update.task_id - assert ( - event.status_update.status.state == TaskState.TASK_STATE_COMPLETED - ) + if use_legacy: + # Legacy streams events as-is: just the status_update(COMPLETED). + (event,) = events + assert event.HasField('status_update') + task_id = event.status_update.task_id + assert ( + event.status_update.status.state + == TaskState.TASK_STATE_COMPLETED + ) + else: + # V2 injects Task(SUBMITTED) first per A2A spec §3.1.2. + assert len(events) == 2 + assert events[0].HasField('task'), ( + 'First streaming event must be a Task or Message' + ) + task_id = events[0].task.id + assert events[1].HasField('status_update') + assert ( + events[1].status_update.status.state + == TaskState.TASK_STATE_COMPLETED + ) else: + (event,) = events assert event.HasField('task') task_id = event.task.id assert event.task.status.state == TaskState.TASK_STATE_COMPLETED @@ -498,8 +526,23 @@ async def cancel( tasks = [] if streaming: - res = await it.__anext__() - assert res.status_update.status.state == TaskState.TASK_STATE_WORKING + if use_legacy: + # Legacy streams events as-is; first event is the WORKING status update. + first = await it.__anext__() + assert ( + first.status_update.status.state == TaskState.TASK_STATE_WORKING + ) + else: + # V2 injects Task(SUBMITTED) first per A2A spec §3.1.2. + first = await it.__anext__() + assert first.HasField('task'), ( + 'First streaming event must be a Task or Message' + ) + second = await it.__anext__() + assert ( + second.status_update.status.state + == TaskState.TASK_STATE_WORKING + ) continue_event.set() else: @@ -1082,10 +1125,19 @@ async def cancel( states = [get_state(event) async for event in it] if streaming: - assert states == [ - TaskState.TASK_STATE_WORKING, - TaskState.TASK_STATE_COMPLETED, - ] + if use_legacy: + # Legacy streams events as-is (no Task(SUBMITTED) injection). + assert states == [ + TaskState.TASK_STATE_WORKING, + TaskState.TASK_STATE_COMPLETED, + ] + else: + # V2 injects Task(SUBMITTED) first per A2A spec §3.1.2. + assert states == [ + TaskState.TASK_STATE_SUBMITTED, + TaskState.TASK_STATE_WORKING, + TaskState.TASK_STATE_COMPLETED, + ] else: assert states == [TaskState.TASK_STATE_WORKING] @@ -1151,11 +1203,27 @@ async def cancel( ) events1 = [event async for event in it] - assert [get_state(event) for event in events1] == [ - TaskState.TASK_STATE_INPUT_REQUIRED, - ] - task_id = events1[0].status_update.task_id - context_id = events1[0].status_update.context_id + if streaming and not use_legacy: + # V2 injects Task(SUBMITTED) first per A2A spec §3.1.2. + assert [get_state(event) for event in events1] == [ + TaskState.TASK_STATE_SUBMITTED, + TaskState.TASK_STATE_INPUT_REQUIRED, + ] + task_id = events1[0].task.id + context_id = events1[0].task.context_id + elif streaming and use_legacy: + # Legacy streams events as-is; first event is the INPUT_REQUIRED status update. + assert [get_state(event) for event in events1] == [ + TaskState.TASK_STATE_INPUT_REQUIRED, + ] + task_id = events1[0].status_update.task_id + context_id = events1[0].status_update.context_id + else: + assert [get_state(event) for event in events1] == [ + TaskState.TASK_STATE_INPUT_REQUIRED, + ] + task_id = events1[0].task.id + context_id = events1[0].task.context_id # Now send another message to resume msg2 = Message( @@ -1240,19 +1308,38 @@ async def cancel( ) if streaming: - event1 = await asyncio.wait_for(it.__anext__(), timeout=1.0) - assert get_state(event1) == TaskState.TASK_STATE_WORKING + if use_legacy: + # Legacy streams events as-is: WORKING → AUTH_REQUIRED → COMPLETED. + event1 = await asyncio.wait_for(it.__anext__(), timeout=1.0) + assert get_state(event1) == TaskState.TASK_STATE_WORKING - event2 = await asyncio.wait_for(it.__anext__(), timeout=1.0) - assert get_state(event2) == TaskState.TASK_STATE_AUTH_REQUIRED + event2 = await asyncio.wait_for(it.__anext__(), timeout=1.0) + assert get_state(event2) == TaskState.TASK_STATE_AUTH_REQUIRED - task_id = event2.status_update.task_id + task_id = event2.status_update.task_id - side_channel_event.set() + side_channel_event.set() + + (event3,) = [event async for event in it] + assert get_state(event3) == TaskState.TASK_STATE_COMPLETED + else: + # V2 injects Task(SUBMITTED) first per A2A spec §3.1.2. + # Full sequence: Task(SUBMITTED) → WORKING → AUTH_REQUIRED → COMPLETED. + event1 = await asyncio.wait_for(it.__anext__(), timeout=1.0) + assert get_state(event1) == TaskState.TASK_STATE_SUBMITTED + + event2 = await asyncio.wait_for(it.__anext__(), timeout=1.0) + assert get_state(event2) == TaskState.TASK_STATE_WORKING + + event3 = await asyncio.wait_for(it.__anext__(), timeout=1.0) + assert get_state(event3) == TaskState.TASK_STATE_AUTH_REQUIRED + + task_id = event3.status_update.task_id + + side_channel_event.set() - # Remaining event. - (event3,) = [event async for event in it] - assert get_state(event3) == TaskState.TASK_STATE_COMPLETED + (event4,) = [event async for event in it] + assert get_state(event4) == TaskState.TASK_STATE_COMPLETED else: (event,) = [event async for event in it] assert get_state(event) == TaskState.TASK_STATE_AUTH_REQUIRED diff --git a/tests/server/request_handlers/test_default_request_handler_v2.py b/tests/server/request_handlers/test_default_request_handler_v2.py index d48b82461..6b0eb824d 100644 --- a/tests/server/request_handlers/test_default_request_handler_v2.py +++ b/tests/server/request_handlers/test_default_request_handler_v2.py @@ -560,14 +560,14 @@ async def consume_stream(): message_params, create_server_call_context() ): events.append(event) - if len(events) >= 3: + if len(events) >= 4: break return events start = time.perf_counter() events = await consume_stream() elapsed = time.perf_counter() - start - assert len(events) == 3 + assert len(events) == 4 assert elapsed < 0.5 texts = [p.text for e in events for p in e.status.message.parts] assert texts == ['Event 0', 'Event 1', 'Event 2']