Skip to content

Commit a669521

Browse files
authored
test: add more scenarios to test_end_to_end (#954)
Based on https://a2a-protocol.org/latest/specification/#312-send-streaming-message: 1. `Message` based flow. 2. Emit `Task` as a first event. # TODO: switches to the old request handler as there are known issues in the new one With a new handler failures are caused by 1. `Task` events are not streamed 2. `return_immediately` + direct message - V2 returns a phantom `Task` before the executor produces its `Message`
1 parent be4c5ff commit a669521

1 file changed

Lines changed: 112 additions & 35 deletions

File tree

tests/integration/test_end_to_end.py

Lines changed: 112 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,17 @@
55
import httpx
66
import pytest
77
import pytest_asyncio
8+
from starlette.applications import Starlette
89

910
from a2a.client.base_client import BaseClient
1011
from a2a.client.client import ClientConfig
1112
from a2a.client.client_factory import ClientFactory
1213
from a2a.server.agent_execution import AgentExecutor, RequestContext
13-
from a2a.server.routes.rest_routes import create_rest_routes
14-
from starlette.applications import Starlette
15-
from a2a.server.routes import create_jsonrpc_routes, create_agent_card_routes
1614
from a2a.server.events import EventQueue
1715
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
18-
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
16+
from a2a.server.request_handlers import GrpcHandler, LegacyRequestHandler
17+
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
18+
from a2a.server.routes.rest_routes import create_rest_routes
1919
from a2a.server.tasks import TaskUpdater
2020
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
2121
from a2a.types import (
@@ -37,7 +37,7 @@
3737
TaskState,
3838
a2a_pb2_grpc,
3939
)
40-
from a2a.utils import TransportProtocol
40+
from a2a.utils import TransportProtocol, new_task
4141
from a2a.utils.errors import InvalidParamsError
4242

4343

@@ -69,7 +69,9 @@ def assert_events_match(events, expected_events):
6969
events, expected_events, strict=True
7070
):
7171
assert event.HasField(expected_type)
72-
if expected_type == 'status_update':
72+
if expected_type == 'task':
73+
assert event.task.status.state == expected_val
74+
elif expected_type == 'status_update':
7375
assert event.status_update.status.state == expected_val
7476
elif expected_type == 'artifact_update':
7577
if expected_val is not None:
@@ -83,26 +85,30 @@ def assert_events_match(events, expected_events):
8385

8486
class MockAgentExecutor(AgentExecutor):
8587
async def execute(self, context: RequestContext, event_queue: EventQueue):
86-
task_updater = TaskUpdater(
87-
event_queue,
88-
context.task_id,
89-
context.context_id,
90-
)
9188
user_input = context.get_user_input()
9289

93-
is_input_required_resumption = (
94-
context.current_task is not None
95-
and context.current_task.status.state
96-
== TaskState.TASK_STATE_INPUT_REQUIRED
97-
)
98-
99-
if not is_input_required_resumption:
100-
await task_updater.update_status(
101-
TaskState.TASK_STATE_SUBMITTED,
102-
message=task_updater.new_agent_message(
103-
[Part(text='task submitted')]
104-
),
90+
# Direct message response (no task created).
91+
if user_input.startswith('Message:'):
92+
await event_queue.enqueue_event(
93+
Message(
94+
role=Role.ROLE_AGENT,
95+
message_id='direct-reply-1',
96+
parts=[Part(text=f'Direct reply to: {user_input}')],
97+
)
10598
)
99+
return
100+
101+
# Task-based response.
102+
task = context.current_task
103+
if not task:
104+
task = new_task(context.message)
105+
await event_queue.enqueue_event(task)
106+
107+
task_updater = TaskUpdater(
108+
event_queue,
109+
task.id,
110+
task.context_id,
111+
)
106112

107113
await task_updater.update_status(
108114
TaskState.TASK_STATE_WORKING,
@@ -168,7 +174,8 @@ class ClientSetup(NamedTuple):
168174
@pytest.fixture
169175
def base_e2e_setup(agent_card):
170176
task_store = InMemoryTaskStore()
171-
handler = DefaultRequestHandler(
177+
# TODO(https://github.com/a2aproject/a2a-python/issues/869): Use DefaultRequestHandler once it's fixed
178+
handler = LegacyRequestHandler(
172179
agent_executor=MockAgentExecutor(),
173180
task_store=task_store,
174181
agent_card=agent_card,
@@ -328,7 +335,6 @@ async def test_end_to_end_send_message_blocking(transport_setups):
328335
response.task.history,
329336
[
330337
(Role.ROLE_USER, 'Run dummy agent!'),
331-
(Role.ROLE_AGENT, 'task submitted'),
332338
(Role.ROLE_AGENT, 'task working'),
333339
],
334340
)
@@ -386,20 +392,19 @@ async def test_end_to_end_send_message_streaming(transport_setups):
386392
assert_events_match(
387393
events,
388394
[
389-
('status_update', TaskState.TASK_STATE_SUBMITTED),
395+
('task', TaskState.TASK_STATE_SUBMITTED),
390396
('status_update', TaskState.TASK_STATE_WORKING),
391397
('artifact_update', [('test-artifact', 'artifact content')]),
392398
('status_update', TaskState.TASK_STATE_COMPLETED),
393399
],
394400
)
395401

396-
task_id = events[0].status_update.task_id
402+
task_id = events[0].task.id
397403
task = await client.get_task(request=GetTaskRequest(id=task_id))
398404
assert_history_matches(
399405
task.history,
400406
[
401407
(Role.ROLE_USER, 'Run dummy agent!'),
402-
(Role.ROLE_AGENT, 'task submitted'),
403408
(Role.ROLE_AGENT, 'task working'),
404409
],
405410
)
@@ -423,7 +428,7 @@ async def test_end_to_end_get_task(transport_setups):
423428
)
424429
]
425430
response = events[0]
426-
task_id = response.status_update.task_id
431+
task_id = response.task.id
427432

428433
get_request = GetTaskRequest(id=task_id)
429434
retrieved_task = await client.get_task(request=get_request)
@@ -438,7 +443,6 @@ async def test_end_to_end_get_task(transport_setups):
438443
retrieved_task.history,
439444
[
440445
(Role.ROLE_USER, 'Test Get Task'),
441-
(Role.ROLE_AGENT, 'task submitted'),
442446
(Role.ROLE_AGENT, 'task working'),
443447
],
444448
)
@@ -465,7 +469,7 @@ async def test_end_to_end_list_tasks(transport_setups):
465469
)
466470
)
467471
)
468-
expected_task_ids.append(response.status_update.task_id)
472+
expected_task_ids.append(response.task.id)
469473

470474
list_request = ListTasksRequest(page_size=page_size)
471475

@@ -514,21 +518,20 @@ async def test_end_to_end_input_required(transport_setups):
514518
assert_events_match(
515519
events,
516520
[
517-
('status_update', TaskState.TASK_STATE_SUBMITTED),
521+
('task', TaskState.TASK_STATE_SUBMITTED),
518522
('status_update', TaskState.TASK_STATE_WORKING),
519523
('status_update', TaskState.TASK_STATE_INPUT_REQUIRED),
520524
],
521525
)
522526

523-
task_id = events[0].status_update.task_id
527+
task_id = events[0].task.id
524528
task = await client.get_task(request=GetTaskRequest(id=task_id))
525529

526530
assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED
527531
assert_history_matches(
528532
task.history,
529533
[
530534
(Role.ROLE_USER, 'Need input'),
531-
(Role.ROLE_AGENT, 'task submitted'),
532535
(Role.ROLE_AGENT, 'task working'),
533536
],
534537
)
@@ -572,7 +575,6 @@ async def test_end_to_end_input_required(transport_setups):
572575
task.history,
573576
[
574577
(Role.ROLE_USER, 'Need input'),
575-
(Role.ROLE_AGENT, 'task submitted'),
576578
(Role.ROLE_AGENT, 'task working'),
577579
(Role.ROLE_AGENT, 'Please provide input'),
578580
(Role.ROLE_USER, 'Here is the input'),
@@ -681,3 +683,78 @@ async def test_end_to_end_subscribe_validation_error(
681683
assert {e['field'] for e in errors} == {'id'}
682684

683685
await client.close()
686+
687+
688+
@pytest.mark.asyncio
689+
@pytest.mark.parametrize(
690+
'streaming',
691+
[
692+
pytest.param(False, id='blocking'),
693+
pytest.param(True, id='streaming'),
694+
],
695+
)
696+
async def test_end_to_end_direct_message(transport_setups, streaming):
697+
"""Test that an executor can return a direct Message without creating a Task."""
698+
client = transport_setups.client
699+
client._config.streaming = streaming
700+
701+
message_to_send = Message(
702+
role=Role.ROLE_USER,
703+
message_id='msg-direct',
704+
parts=[Part(text='Message: Hello agent')],
705+
)
706+
707+
events = [
708+
event
709+
async for event in client.send_message(
710+
request=SendMessageRequest(message=message_to_send)
711+
)
712+
]
713+
714+
assert len(events) == 1
715+
response = events[0]
716+
assert response.HasField('message')
717+
assert not response.HasField('task')
718+
assert_message_matches(
719+
response.message,
720+
Role.ROLE_AGENT,
721+
'Direct reply to: Message: Hello agent',
722+
)
723+
724+
725+
@pytest.mark.asyncio
726+
async def test_end_to_end_direct_message_return_immediately(transport_setups):
727+
"""Test that return_immediately still returns the Message for direct replies.
728+
729+
When the executor responds with a direct Message, the response is
730+
inherently immediate -- there is no async task to defer to. The client
731+
should receive the Message regardless of the return_immediately flag.
732+
"""
733+
client = transport_setups.client
734+
client._config.streaming = False
735+
736+
message_to_send = Message(
737+
role=Role.ROLE_USER,
738+
message_id='msg-direct-return-immediately',
739+
parts=[Part(text='Message: Quick question')],
740+
)
741+
configuration = SendMessageConfiguration(return_immediately=True)
742+
743+
events = [
744+
event
745+
async for event in client.send_message(
746+
request=SendMessageRequest(
747+
message=message_to_send, configuration=configuration
748+
)
749+
)
750+
]
751+
752+
assert len(events) == 1
753+
response = events[0]
754+
assert response.HasField('message')
755+
assert not response.HasField('task')
756+
assert_message_matches(
757+
response.message,
758+
Role.ROLE_AGENT,
759+
'Direct reply to: Message: Quick question',
760+
)

0 commit comments

Comments
 (0)