33import base64
44import logging
55import os
6+ import signal
67import uuid
78
89import grpc
910import httpx
1011import uvicorn
1112
1213from fastapi import FastAPI
14+ from typing import Any
1315
1416from pyproto import instruction_pb2
1517
16- from a2a .client import ClientConfig , create_client
18+ from a2a .client import Client , ClientConfig , create_client
19+ from a2a .client .errors import A2AClientError
1720from a2a .compat .v0_3 import a2a_v0_3_pb2_grpc
1821from a2a .compat .v0_3 .grpc_handler import CompatGrpcHandler
1922from a2a .server .agent_execution import AgentExecutor , RequestContext
3639 AgentCapabilities ,
3740 AgentCard ,
3841 AgentInterface ,
42+ CancelTaskRequest ,
3943 Message ,
4044 Part ,
4145 SendMessageRequest ,
46+ SubscribeToTaskRequest ,
4247 Task ,
4348 TaskState ,
4449 TaskStatus ,
@@ -98,6 +103,95 @@ def extract_instruction(
98103 return None
99104
100105
106+ def _extract_text_from_event (event : Any ) -> list [str ]:
107+ """Extracts text parts from an event's message."""
108+ if isinstance (event , tuple ):
109+ results = []
110+ for item in event :
111+ results .extend (_extract_text_from_event (item ))
112+ return results
113+
114+ message = None
115+ if hasattr (event , 'HasField' ):
116+ if event .HasField ('message' ):
117+ message = event .message
118+ elif event .HasField ('task' ) and event .task .status .HasField ('message' ):
119+ message = event .task .status .message
120+ elif event .HasField (
121+ 'status_update'
122+ ) and event .status_update .status .HasField ('message' ):
123+ message = event .status_update .status .message
124+
125+ results = []
126+ if message :
127+ results .extend (part .text for part in message .parts if part .text )
128+ return results
129+
130+
131+ async def _handle_call_agent_with_resubscribe (
132+ client : Client , request : SendMessageRequest
133+ ) -> list [str ]:
134+ """Handles the send-disconnect-resubscribe flow."""
135+ results = []
136+ logger .info ('Executing re-subscribe behavior' )
137+ agen = client .send_message (request )
138+ task_id = None
139+
140+ async for event in agen :
141+ logger .info ('Event before disconnect: %s' , event )
142+ if event .HasField ('task' ):
143+ task_id = event .task .id
144+ elif event .HasField ('status_update' ):
145+ task_id = event .status_update .task_id
146+ break
147+
148+ await agen .aclose ()
149+ logger .info ('Disconnected from task %s. Now re-subscribing.' , task_id )
150+
151+ resub_agen = client .subscribe (SubscribeToTaskRequest (id = task_id ))
152+
153+ task_obj = None
154+ finished = False
155+ async for event in resub_agen :
156+ logger .info ('Event after re-subscribe: %s' , event )
157+ if hasattr (event , 'HasField' ) and event .HasField ('task' ):
158+ task_obj = event .task
159+
160+ extracted_text = _extract_text_from_event (event )
161+ for text in extracted_text :
162+ processed_text = text .replace ('task-finished' , '' )
163+ results .append (processed_text )
164+ if any ('task-finished' in text for text in extracted_text ):
165+ logger .info (
166+ 'Received task-finished after re-subscribe, breaking loop.'
167+ )
168+ finished = True
169+ break
170+
171+ if not results and task_obj and hasattr (task_obj , 'history' ):
172+ logger .info ('Results empty after loop, reading from history.' )
173+ for msg in task_obj .history :
174+ # Check stringified role to support protobuf enums (2 for ROLE_AGENT in v0.3 and v1.0)
175+ # as well as string descriptors from dict/JSON forms.
176+ if str (msg .role ) in {'2' , 'ROLE_AGENT' , 'agent' }:
177+ results .extend (
178+ part .text .replace ('task-finished' , '' )
179+ for part in msg .parts
180+ if part .text
181+ )
182+
183+ if not finished :
184+ logger .info ('Canceling task %s after retrieval.' , task_id )
185+ try :
186+ await client .cancel_task (CancelTaskRequest (id = task_id ))
187+ logger .info ('Task cancelled successfully: %s' , task_id )
188+ except A2AClientError :
189+ logger .exception ('Failed to cancel task %s' , task_id )
190+ raise
191+
192+ return results
193+
194+
101195def wrap_instruction_to_request (inst : instruction_pb2 .Instruction ) -> Message :
102196 """Wraps an Instruction proto into an A2A Message."""
103197 inst_bytes = inst .SerializeToString ()
@@ -129,18 +223,22 @@ async def handle_call_agent(
129223 'GRPC' : TransportProtocol .GRPC ,
130224 }
131225
132- selected_transport = transport_map .get (call .transport .upper ())
226+ selected_transport = transport_map .get (
227+ call .transport .upper (), TransportProtocol .JSONRPC
228+ )
133229 if selected_transport is None :
134230 raise ValueError (f'Unsupported transport: { call .transport } ' )
135231
136232 config = ClientConfig ()
137- config .httpx_client = httpx .AsyncClient (timeout = 30.0 )
138233 config .grpc_channel_factory = grpc .aio .insecure_channel
139234 config .supported_protocol_bindings = [selected_transport ]
140235 config .streaming = call .streaming or (
141236 selected_transport == TransportProtocol .GRPC
142237 )
143238
239+ if call .HasField ('resubscribe' ) and not config .streaming :
240+ raise ValueError ('Re-subscription requires streaming to be enabled' )
241+
144242 if call .HasField ('push_notification' ):
145243 url = call .push_notification .url
146244 if not url :
@@ -152,44 +250,45 @@ async def handle_call_agent(
152250 token = 'itk-token' , # noqa: S106
153251 )
154252
155- try :
156- client = await create_client (
157- call .agent_card_uri ,
158- client_config = config ,
159- )
253+ async with httpx .AsyncClient (timeout = 30.0 ) as httpx_client :
254+ config .httpx_client = httpx_client
255+ try :
256+ client = await create_client (
257+ call .agent_card_uri ,
258+ client_config = config ,
259+ )
160260
161- # Wrap nested instruction
162- nested_msg = wrap_instruction_to_request (call .instruction )
163- request = SendMessageRequest (message = nested_msg )
261+ # Wrap nested instruction
262+ nested_msg = wrap_instruction_to_request (call .instruction )
263+ request = SendMessageRequest (message = nested_msg )
164264
165- results = []
166- async for event in client .send_message (request ):
167- # Event is streaming response and task
168- logger .info ('Event: %s' , event )
169- stream_resp = event
170-
171- message = None
172- if stream_resp .HasField ('message' ):
173- message = stream_resp .message
174- elif stream_resp .HasField (
175- 'task'
176- ) and stream_resp .task .status .HasField ('message' ):
177- message = stream_resp .task .status .message
178- elif stream_resp .HasField (
179- 'status_update'
180- ) and stream_resp .status_update .status .HasField ('message' ):
181- message = stream_resp .status_update .status .message
182-
183- if message :
184- results .extend (part .text for part in message .parts if part .text )
185-
186- except Exception as e :
187- logger .exception ('Failed to call outbound agent' )
188- raise RuntimeError (
189- f'Outbound call to { call .agent_card_uri } failed: { e !s} '
190- ) from e
191- else :
192- return results
265+ results = []
266+
267+ if call .HasField ('resubscribe' ):
268+ results .extend (
269+ await _handle_call_agent_with_resubscribe (client , request )
270+ )
271+ else :
272+ async for event in client .send_message (request ):
273+ logger .info ('Event: %s' , event )
274+ results .extend (_extract_text_from_event (event ))
275+
276+ except Exception as e :
277+ logger .exception ('Failed to call outbound agent' )
278+ raise RuntimeError (
279+ f'Outbound call to { call .agent_card_uri } failed: { e !s} '
280+ ) from e
281+ else :
282+ return results
283+
284+
285+ def _should_hold (inst : instruction_pb2 .Instruction ) -> bool :
286+ """Recursively checks if any part of the instruction requests holding the task."""
287+ if inst .HasField ('return_response' ) and inst .return_response .hold_task :
288+ return True
289+ if inst .HasField ('steps' ):
290+ return any (_should_hold (step ) for step in inst .steps .instructions )
291+ return False
193292
194293
195294async def handle_instruction (
@@ -245,23 +344,58 @@ async def execute(
245344 )
246345 return
247346
347+ should_hold_task = _should_hold (instruction )
348+
248349 try :
249350 logger .info ('Instruction: %s' , instruction )
250351 results = await handle_instruction (instruction )
352+
251353 response_text = '\n ' .join (results )
252354 logger .info ('Response: %s' , response_text )
253- await task_updater .update_status (
254- TaskState .TASK_STATE_COMPLETED ,
255- message = task_updater .new_agent_message (
256- [Part (text = response_text )]
257- ),
258- )
259- logger .info ('Task %s completed' , context .task_id )
260- except Exception as e :
355+
356+ if should_hold_task :
357+ logger .info ('Holding task %s as requested' , context .task_id )
358+ # Emitted event: response + task-finished
359+ logger .info (
360+ 'Emitting response and task-finished for held task %s' ,
361+ context .task_id ,
362+ )
363+ await task_updater .update_status (
364+ TaskState .TASK_STATE_WORKING ,
365+ message = task_updater .new_agent_message (
366+ [Part (text = response_text + '\n ' + 'task-finished' )]
367+ ),
368+ )
369+ await asyncio .sleep (2 )
370+
371+ # Continue emitting "task-finished" every 2 seconds
372+ try :
373+ while True :
374+ logger .info (
375+ 'Emitting periodic status update for held task %s' ,
376+ context .task_id ,
377+ )
378+ await task_updater .update_status (
379+ TaskState .TASK_STATE_WORKING ,
380+ message = None ,
381+ )
382+ await asyncio .sleep (2 )
383+ except asyncio .CancelledError :
384+ logger .info ('Task %s cancelled' , context .task_id )
385+ return
386+ else :
387+ await task_updater .update_status (
388+ TaskState .TASK_STATE_COMPLETED ,
389+ message = task_updater .new_agent_message (
390+ [Part (text = response_text )]
391+ ),
392+ )
393+ logger .info ('Task %s completed' , context .task_id )
394+ except Exception :
261395 logger .exception ('Error during instruction handling' )
262396 await task_updater .update_status (
263397 TaskState .TASK_STATE_FAILED ,
264- message = task_updater . new_agent_message ([ Part ( text = str ( e ))]) ,
398+ message = None ,
265399 )
266400
267401 async def cancel (
@@ -325,18 +459,17 @@ async def main_async(http_port: int, grpc_port: int) -> None:
325459 name = 'ITK v10 Agent' ,
326460 description = 'Python agent using SDK 1.0.' ,
327461 version = '1.0.0' ,
328- capabilities = AgentCapabilities (
329- streaming = True , push_notifications = True , extended_agent_card = True
330- ),
462+ capabilities = AgentCapabilities (streaming = True ),
331463 default_input_modes = ['text/plain' ],
332464 default_output_modes = ['text/plain' ],
333465 supported_interfaces = interfaces ,
334466 )
335467
336468 task_store = InMemoryTaskStore ()
337469 push_config_store = InMemoryPushNotificationConfigStore ()
470+ httpx_client = httpx .AsyncClient ()
338471 push_sender = BasePushNotificationSender (
339- httpx_client = httpx . AsyncClient () ,
472+ httpx_client = httpx_client ,
340473 config_store = push_config_store ,
341474 )
342475
@@ -400,6 +533,18 @@ async def main_async(http_port: int, grpc_port: int) -> None:
400533 )
401534 uvicorn_server = uvicorn .Server (config )
402535
536+ # Signal handling
537+ loop = asyncio .get_running_loop ()
538+
539+ async def shutdown () -> None :
540+ logger .info ('Shutting down...' )
541+ uvicorn_server .should_exit = True
542+ await server .stop (5 )
543+ await httpx_client .aclose ()
544+
545+ for sig in (signal .SIGINT , signal .SIGTERM ):
546+ loop .add_signal_handler (sig , lambda : asyncio .create_task (shutdown ()))
547+
403548 await uvicorn_server .serve ()
404549
405550
0 commit comments