55import httpx
66import pytest
77import pytest_asyncio
8+ from starlette .applications import Starlette
89
910from a2a .client .base_client import BaseClient
1011from a2a .client .client import ClientConfig
1112from a2a .client .client_factory import ClientFactory
1213from a2a .server .agent_execution import AgentExecutor , RequestContext
13- from a2a .server .routes .rest_routes import create_rest_routes
14- from starlette .applications import Starlette
15- from a2a .server .routes import create_jsonrpc_routes , create_agent_card_routes
1614from a2a .server .events import EventQueue
1715from a2a .server .events .in_memory_queue_manager import InMemoryQueueManager
18- from a2a .server .request_handlers import DefaultRequestHandler , GrpcHandler
16+ from a2a .server .request_handlers import GrpcHandler , LegacyRequestHandler
17+ from a2a .server .routes import create_agent_card_routes , create_jsonrpc_routes
18+ from a2a .server .routes .rest_routes import create_rest_routes
1919from a2a .server .tasks import TaskUpdater
2020from a2a .server .tasks .inmemory_task_store import InMemoryTaskStore
2121from a2a .types import (
3737 TaskState ,
3838 a2a_pb2_grpc ,
3939)
40- from a2a .utils import TransportProtocol
40+ from a2a .utils import TransportProtocol , new_task
4141from a2a .utils .errors import InvalidParamsError
4242
4343
@@ -69,7 +69,9 @@ def assert_events_match(events, expected_events):
6969 events , expected_events , strict = True
7070 ):
7171 assert event .HasField (expected_type )
72- if expected_type == 'status_update' :
72+ if expected_type == 'task' :
73+ assert event .task .status .state == expected_val
74+ elif expected_type == 'status_update' :
7375 assert event .status_update .status .state == expected_val
7476 elif expected_type == 'artifact_update' :
7577 if expected_val is not None :
@@ -83,26 +85,30 @@ def assert_events_match(events, expected_events):
8385
8486class MockAgentExecutor (AgentExecutor ):
8587 async def execute (self , context : RequestContext , event_queue : EventQueue ):
86- task_updater = TaskUpdater (
87- event_queue ,
88- context .task_id ,
89- context .context_id ,
90- )
9188 user_input = context .get_user_input ()
9289
93- is_input_required_resumption = (
94- context .current_task is not None
95- and context .current_task .status .state
96- == TaskState .TASK_STATE_INPUT_REQUIRED
97- )
98-
99- if not is_input_required_resumption :
100- await task_updater .update_status (
101- TaskState .TASK_STATE_SUBMITTED ,
102- message = task_updater .new_agent_message (
103- [Part (text = 'task submitted' )]
104- ),
90+ # Direct message response (no task created).
91+ if user_input .startswith ('Message:' ):
92+ await event_queue .enqueue_event (
93+ Message (
94+ role = Role .ROLE_AGENT ,
95+ message_id = 'direct-reply-1' ,
96+ parts = [Part (text = f'Direct reply to: { user_input } ' )],
97+ )
10598 )
99+ return
100+
101+ # Task-based response.
102+ task = context .current_task
103+ if not task :
104+ task = new_task (context .message )
105+ await event_queue .enqueue_event (task )
106+
107+ task_updater = TaskUpdater (
108+ event_queue ,
109+ task .id ,
110+ task .context_id ,
111+ )
106112
107113 await task_updater .update_status (
108114 TaskState .TASK_STATE_WORKING ,
@@ -168,7 +174,8 @@ class ClientSetup(NamedTuple):
168174@pytest .fixture
169175def base_e2e_setup (agent_card ):
170176 task_store = InMemoryTaskStore ()
171- handler = DefaultRequestHandler (
177+ # TODO(https://github.com/a2aproject/a2a-python/issues/869): Use DefaultRequestHandler once it's fixed
178+ handler = LegacyRequestHandler (
172179 agent_executor = MockAgentExecutor (),
173180 task_store = task_store ,
174181 agent_card = agent_card ,
@@ -328,7 +335,6 @@ async def test_end_to_end_send_message_blocking(transport_setups):
328335 response .task .history ,
329336 [
330337 (Role .ROLE_USER , 'Run dummy agent!' ),
331- (Role .ROLE_AGENT , 'task submitted' ),
332338 (Role .ROLE_AGENT , 'task working' ),
333339 ],
334340 )
@@ -386,20 +392,19 @@ async def test_end_to_end_send_message_streaming(transport_setups):
386392 assert_events_match (
387393 events ,
388394 [
389- ('status_update ' , TaskState .TASK_STATE_SUBMITTED ),
395+ ('task ' , TaskState .TASK_STATE_SUBMITTED ),
390396 ('status_update' , TaskState .TASK_STATE_WORKING ),
391397 ('artifact_update' , [('test-artifact' , 'artifact content' )]),
392398 ('status_update' , TaskState .TASK_STATE_COMPLETED ),
393399 ],
394400 )
395401
396- task_id = events [0 ].status_update . task_id
402+ task_id = events [0 ].task . id
397403 task = await client .get_task (request = GetTaskRequest (id = task_id ))
398404 assert_history_matches (
399405 task .history ,
400406 [
401407 (Role .ROLE_USER , 'Run dummy agent!' ),
402- (Role .ROLE_AGENT , 'task submitted' ),
403408 (Role .ROLE_AGENT , 'task working' ),
404409 ],
405410 )
@@ -423,7 +428,7 @@ async def test_end_to_end_get_task(transport_setups):
423428 )
424429 ]
425430 response = events [0 ]
426- task_id = response .status_update . task_id
431+ task_id = response .task . id
427432
428433 get_request = GetTaskRequest (id = task_id )
429434 retrieved_task = await client .get_task (request = get_request )
@@ -438,7 +443,6 @@ async def test_end_to_end_get_task(transport_setups):
438443 retrieved_task .history ,
439444 [
440445 (Role .ROLE_USER , 'Test Get Task' ),
441- (Role .ROLE_AGENT , 'task submitted' ),
442446 (Role .ROLE_AGENT , 'task working' ),
443447 ],
444448 )
@@ -465,7 +469,7 @@ async def test_end_to_end_list_tasks(transport_setups):
465469 )
466470 )
467471 )
468- expected_task_ids .append (response .status_update . task_id )
472+ expected_task_ids .append (response .task . id )
469473
470474 list_request = ListTasksRequest (page_size = page_size )
471475
@@ -514,21 +518,20 @@ async def test_end_to_end_input_required(transport_setups):
514518 assert_events_match (
515519 events ,
516520 [
517- ('status_update ' , TaskState .TASK_STATE_SUBMITTED ),
521+ ('task ' , TaskState .TASK_STATE_SUBMITTED ),
518522 ('status_update' , TaskState .TASK_STATE_WORKING ),
519523 ('status_update' , TaskState .TASK_STATE_INPUT_REQUIRED ),
520524 ],
521525 )
522526
523- task_id = events [0 ].status_update . task_id
527+ task_id = events [0 ].task . id
524528 task = await client .get_task (request = GetTaskRequest (id = task_id ))
525529
526530 assert task .status .state == TaskState .TASK_STATE_INPUT_REQUIRED
527531 assert_history_matches (
528532 task .history ,
529533 [
530534 (Role .ROLE_USER , 'Need input' ),
531- (Role .ROLE_AGENT , 'task submitted' ),
532535 (Role .ROLE_AGENT , 'task working' ),
533536 ],
534537 )
@@ -572,7 +575,6 @@ async def test_end_to_end_input_required(transport_setups):
572575 task .history ,
573576 [
574577 (Role .ROLE_USER , 'Need input' ),
575- (Role .ROLE_AGENT , 'task submitted' ),
576578 (Role .ROLE_AGENT , 'task working' ),
577579 (Role .ROLE_AGENT , 'Please provide input' ),
578580 (Role .ROLE_USER , 'Here is the input' ),
@@ -681,3 +683,78 @@ async def test_end_to_end_subscribe_validation_error(
681683 assert {e ['field' ] for e in errors } == {'id' }
682684
683685 await client .close ()
686+
687+
688+ @pytest .mark .asyncio
689+ @pytest .mark .parametrize (
690+ 'streaming' ,
691+ [
692+ pytest .param (False , id = 'blocking' ),
693+ pytest .param (True , id = 'streaming' ),
694+ ],
695+ )
696+ async def test_end_to_end_direct_message (transport_setups , streaming ):
697+ """Test that an executor can return a direct Message without creating a Task."""
698+ client = transport_setups .client
699+ client ._config .streaming = streaming
700+
701+ message_to_send = Message (
702+ role = Role .ROLE_USER ,
703+ message_id = 'msg-direct' ,
704+ parts = [Part (text = 'Message: Hello agent' )],
705+ )
706+
707+ events = [
708+ event
709+ async for event in client .send_message (
710+ request = SendMessageRequest (message = message_to_send )
711+ )
712+ ]
713+
714+ assert len (events ) == 1
715+ response = events [0 ]
716+ assert response .HasField ('message' )
717+ assert not response .HasField ('task' )
718+ assert_message_matches (
719+ response .message ,
720+ Role .ROLE_AGENT ,
721+ 'Direct reply to: Message: Hello agent' ,
722+ )
723+
724+
725+ @pytest .mark .asyncio
726+ async def test_end_to_end_direct_message_return_immediately (transport_setups ):
727+ """Test that return_immediately still returns the Message for direct replies.
728+
729+ When the executor responds with a direct Message, the response is
730+ inherently immediate -- there is no async task to defer to. The client
731+ should receive the Message regardless of the return_immediately flag.
732+ """
733+ client = transport_setups .client
734+ client ._config .streaming = False
735+
736+ message_to_send = Message (
737+ role = Role .ROLE_USER ,
738+ message_id = 'msg-direct-return-immediately' ,
739+ parts = [Part (text = 'Message: Quick question' )],
740+ )
741+ configuration = SendMessageConfiguration (return_immediately = True )
742+
743+ events = [
744+ event
745+ async for event in client .send_message (
746+ request = SendMessageRequest (
747+ message = message_to_send , configuration = configuration
748+ )
749+ )
750+ ]
751+
752+ assert len (events ) == 1
753+ response = events [0 ]
754+ assert response .HasField ('message' )
755+ assert not response .HasField ('task' )
756+ assert_message_matches (
757+ response .message ,
758+ Role .ROLE_AGENT ,
759+ 'Direct reply to: Message: Quick question' ,
760+ )
0 commit comments