Skip to content

Commit c8f3df3

Browse files
kdziedzic70Krzysztof Dziedzic
andauthored
test: setup itk resubscribe tests (#1031)
# Description PR adjusts itk tests with resubscribe behavior check Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Follow the [`CONTRIBUTING` Guide](https://github.com/a2aproject/a2a-python/blob/main/CONTRIBUTING.md). - [ ] Make your Pull Request title in the <https://www.conventionalcommits.org/> specification. - Important Prefixes for [release-please](https://github.com/googleapis/release-please): - `fix:` which represents bug fixes, and correlates to a [SemVer](https://semver.org/) patch. - `feat:` represents a new feature, and correlates to a SemVer minor. - `feat!:`, or `fix!:`, `refactor!:`, etc., which represent a breaking change (indicated by the `!`) and will result in a SemVer major. - [ ] Ensure the tests and linter pass (Run `bash scripts/format.sh` from the repository root to format) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕 Co-authored-by: Krzysztof Dziedzic <dziedzick@google.com>
1 parent e65e758 commit c8f3df3

4 files changed

Lines changed: 217 additions & 54 deletions

File tree

.github/workflows/itk.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ jobs:
3131
run: bash run_itk.sh
3232
working-directory: itk
3333
env:
34-
A2A_SAMPLES_REVISION: itk-v.02-alpha
34+
A2A_SAMPLES_REVISION: itk-v.021-alpha

itk/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ You must set the `A2A_SAMPLES_REVISION` environment variable to specify which re
3636

3737
Example:
3838
```
39-
export A2A_SAMPLES_REVISION=itk-v.02-alpha
39+
export A2A_SAMPLES_REVISION=itk-v.021-alpha
4040
```
4141

4242
### 2. Execute Tests

itk/main.py

Lines changed: 197 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,20 @@
33
import base64
44
import logging
55
import os
6+
import signal
67
import uuid
78

89
import grpc
910
import httpx
1011
import uvicorn
1112

1213
from fastapi import FastAPI
14+
from typing import Any
1315

1416
from pyproto import instruction_pb2
1517

16-
from a2a.client import ClientConfig, create_client
18+
from a2a.client import Client, ClientConfig, create_client
19+
from a2a.client.errors import A2AClientError
1720
from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
1821
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
1922
from a2a.server.agent_execution import AgentExecutor, RequestContext
@@ -36,9 +39,11 @@
3639
AgentCapabilities,
3740
AgentCard,
3841
AgentInterface,
42+
CancelTaskRequest,
3943
Message,
4044
Part,
4145
SendMessageRequest,
46+
SubscribeToTaskRequest,
4247
Task,
4348
TaskState,
4449
TaskStatus,
@@ -98,6 +103,95 @@ def extract_instruction(
98103
return None
99104

100105

106+
def _extract_text_from_event(event: Any) -> list[str]:
107+
"""Extracts text parts from an event's message."""
108+
if isinstance(event, tuple):
109+
results = []
110+
for item in event:
111+
results.extend(_extract_text_from_event(item))
112+
return results
113+
114+
message = None
115+
if hasattr(event, 'HasField'):
116+
if event.HasField('message'):
117+
message = event.message
118+
elif event.HasField('task') and event.task.status.HasField('message'):
119+
message = event.task.status.message
120+
elif event.HasField(
121+
'status_update'
122+
) and event.status_update.status.HasField('message'):
123+
message = event.status_update.status.message
124+
125+
results = []
126+
if message:
127+
results.extend(part.text for part in message.parts if part.text)
128+
return results
129+
130+
131+
async def _handle_call_agent_with_resubscribe(
132+
client: Client, request: SendMessageRequest
133+
) -> list[str]:
134+
"""Handles the send-disconnect-resubscribe flow."""
135+
results = []
136+
logger.info('Executing re-subscribe behavior')
137+
agen = client.send_message(request)
138+
task_id = None
139+
140+
async for event in agen:
141+
logger.info('Event before disconnect: %s', event)
142+
if event.HasField('task'):
143+
task_id = event.task.id
144+
elif event.HasField('status_update'):
145+
task_id = event.status_update.task_id
146+
break
147+
148+
await agen.aclose()
149+
logger.info('Disconnected from task %s. Now re-subscribing.', task_id)
150+
151+
resub_agen = client.subscribe(SubscribeToTaskRequest(id=task_id))
152+
153+
task_obj = None
154+
finished = False
155+
async for event in resub_agen:
156+
logger.info('Event after re-subscribe: %s', event)
157+
if hasattr(event, 'HasField') and event.HasField('task'):
158+
task_obj = event.task
159+
160+
extracted_text = _extract_text_from_event(event)
161+
for text in extracted_text:
162+
processed_text = text.replace('task-finished', '')
163+
results.append(processed_text)
164+
if any('task-finished' in text for text in extracted_text):
165+
logger.info(
166+
'Received task-finished after re-subscribe, breaking loop.'
167+
)
168+
finished = True
169+
break
170+
171+
if not results and task_obj and hasattr(task_obj, 'history'):
172+
logger.info('Results empty after loop, reading from history.')
173+
for msg in task_obj.history:
174+
# Check stringified role to support protobuf enums (2 for ROLE_AGENT in v0.3 and v1.0)
175+
# as well as string descriptors from dict/JSON forms.
176+
if str(msg.role) in {'2', 'ROLE_AGENT', 'agent'}:
177+
results.extend(
178+
part.text.replace('task-finished', '')
179+
for part in msg.parts
180+
if part.text
181+
)
182+
183+
if not finished:
184+
logger.info('Canceling task %s after retrieval.', task_id)
185+
try:
186+
await client.cancel_task(CancelTaskRequest(id=task_id))
187+
logger.info('Task cancelled successfully: %s', task_id)
188+
except A2AClientError:
189+
logger.exception('Failed to cancel task %s', task_id)
190+
raise
191+
192+
return results
193+
194+
101195
def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
102196
"""Wraps an Instruction proto into an A2A Message."""
103197
inst_bytes = inst.SerializeToString()
@@ -129,18 +223,22 @@ async def handle_call_agent(
129223
'GRPC': TransportProtocol.GRPC,
130224
}
131225

132-
selected_transport = transport_map.get(call.transport.upper())
226+
selected_transport = transport_map.get(
227+
call.transport.upper(), TransportProtocol.JSONRPC
228+
)
133229
if selected_transport is None:
134230
raise ValueError(f'Unsupported transport: {call.transport}')
135231

136232
config = ClientConfig()
137-
config.httpx_client = httpx.AsyncClient(timeout=30.0)
138233
config.grpc_channel_factory = grpc.aio.insecure_channel
139234
config.supported_protocol_bindings = [selected_transport]
140235
config.streaming = call.streaming or (
141236
selected_transport == TransportProtocol.GRPC
142237
)
143238

239+
if call.HasField('resubscribe') and not config.streaming:
240+
raise ValueError('Re-subscription requires streaming to be enabled')
241+
144242
if call.HasField('push_notification'):
145243
url = call.push_notification.url
146244
if not url:
@@ -152,44 +250,45 @@ async def handle_call_agent(
152250
token='itk-token', # noqa: S106
153251
)
154252

155-
try:
156-
client = await create_client(
157-
call.agent_card_uri,
158-
client_config=config,
159-
)
253+
async with httpx.AsyncClient(timeout=30.0) as httpx_client:
254+
config.httpx_client = httpx_client
255+
try:
256+
client = await create_client(
257+
call.agent_card_uri,
258+
client_config=config,
259+
)
160260

161-
# Wrap nested instruction
162-
nested_msg = wrap_instruction_to_request(call.instruction)
163-
request = SendMessageRequest(message=nested_msg)
261+
# Wrap nested instruction
262+
nested_msg = wrap_instruction_to_request(call.instruction)
263+
request = SendMessageRequest(message=nested_msg)
164264

165-
results = []
166-
async for event in client.send_message(request):
167-
# Event is streaming response and task
168-
logger.info('Event: %s', event)
169-
stream_resp = event
170-
171-
message = None
172-
if stream_resp.HasField('message'):
173-
message = stream_resp.message
174-
elif stream_resp.HasField(
175-
'task'
176-
) and stream_resp.task.status.HasField('message'):
177-
message = stream_resp.task.status.message
178-
elif stream_resp.HasField(
179-
'status_update'
180-
) and stream_resp.status_update.status.HasField('message'):
181-
message = stream_resp.status_update.status.message
182-
183-
if message:
184-
results.extend(part.text for part in message.parts if part.text)
185-
186-
except Exception as e:
187-
logger.exception('Failed to call outbound agent')
188-
raise RuntimeError(
189-
f'Outbound call to {call.agent_card_uri} failed: {e!s}'
190-
) from e
191-
else:
192-
return results
265+
results = []
266+
267+
if call.HasField('resubscribe'):
268+
results.extend(
269+
await _handle_call_agent_with_resubscribe(client, request)
270+
)
271+
else:
272+
async for event in client.send_message(request):
273+
logger.info('Event: %s', event)
274+
results.extend(_extract_text_from_event(event))
275+
276+
except Exception as e:
277+
logger.exception('Failed to call outbound agent')
278+
raise RuntimeError(
279+
f'Outbound call to {call.agent_card_uri} failed: {e!s}'
280+
) from e
281+
else:
282+
return results
283+
284+
285+
def _should_hold(inst: instruction_pb2.Instruction) -> bool:
286+
"""Recursively checks if any part of the instruction requests holding the task."""
287+
if inst.HasField('return_response') and inst.return_response.hold_task:
288+
return True
289+
if inst.HasField('steps'):
290+
return any(_should_hold(step) for step in inst.steps.instructions)
291+
return False
193292

194293

195294
async def handle_instruction(
@@ -245,23 +344,58 @@ async def execute(
245344
)
246345
return
247346

347+
should_hold_task = _should_hold(instruction)
348+
248349
try:
249350
logger.info('Instruction: %s', instruction)
250351
results = await handle_instruction(instruction)
352+
251353
response_text = '\n'.join(results)
252354
logger.info('Response: %s', response_text)
253-
await task_updater.update_status(
254-
TaskState.TASK_STATE_COMPLETED,
255-
message=task_updater.new_agent_message(
256-
[Part(text=response_text)]
257-
),
258-
)
259-
logger.info('Task %s completed', context.task_id)
260-
except Exception as e:
355+
356+
if should_hold_task:
357+
logger.info('Holding task %s as requested', context.task_id)
358+
# Emitted event: response + task-finished
359+
logger.info(
360+
'Emitting response and task-finished for held task %s',
361+
context.task_id,
362+
)
363+
await task_updater.update_status(
364+
TaskState.TASK_STATE_WORKING,
365+
message=task_updater.new_agent_message(
366+
[Part(text=response_text + '\n' + 'task-finished')]
367+
),
368+
)
369+
await asyncio.sleep(2)
370+
371+
# Continue emitting "task-finished" every 2 seconds
372+
try:
373+
while True:
374+
logger.info(
375+
'Emitting periodic status update for held task %s',
376+
context.task_id,
377+
)
378+
await task_updater.update_status(
379+
TaskState.TASK_STATE_WORKING,
380+
message=None,
381+
)
382+
await asyncio.sleep(2)
383+
except asyncio.CancelledError:
384+
logger.info('Task %s cancelled', context.task_id)
385+
return
386+
else:
387+
await task_updater.update_status(
388+
TaskState.TASK_STATE_COMPLETED,
389+
message=task_updater.new_agent_message(
390+
[Part(text=response_text)]
391+
),
392+
)
393+
logger.info('Task %s completed', context.task_id)
394+
except Exception:
261395
logger.exception('Error during instruction handling')
262396
await task_updater.update_status(
263397
TaskState.TASK_STATE_FAILED,
264-
message=task_updater.new_agent_message([Part(text=str(e))]),
398+
message=None,
265399
)
266400

267401
async def cancel(
@@ -325,18 +459,17 @@ async def main_async(http_port: int, grpc_port: int) -> None:
325459
name='ITK v10 Agent',
326460
description='Python agent using SDK 1.0.',
327461
version='1.0.0',
328-
capabilities=AgentCapabilities(
329-
streaming=True, push_notifications=True, extended_agent_card=True
330-
),
462+
capabilities=AgentCapabilities(streaming=True),
331463
default_input_modes=['text/plain'],
332464
default_output_modes=['text/plain'],
333465
supported_interfaces=interfaces,
334466
)
335467

336468
task_store = InMemoryTaskStore()
337469
push_config_store = InMemoryPushNotificationConfigStore()
470+
httpx_client = httpx.AsyncClient()
338471
push_sender = BasePushNotificationSender(
339-
httpx_client=httpx.AsyncClient(),
472+
httpx_client=httpx_client,
340473
config_store=push_config_store,
341474
)
342475

@@ -400,6 +533,18 @@ async def main_async(http_port: int, grpc_port: int) -> None:
400533
)
401534
uvicorn_server = uvicorn.Server(config)
402535

536+
# Signal handling
537+
loop = asyncio.get_running_loop()
538+
539+
async def shutdown() -> None:
540+
logger.info('Shutting down...')
541+
uvicorn_server.should_exit = True
542+
await server.stop(5)
543+
await httpx_client.aclose()
544+
545+
for sig in (signal.SIGINT, signal.SIGTERM):
546+
loop.add_signal_handler(sig, lambda: asyncio.create_task(shutdown()))
547+
403548
await uvicorn_server.serve()
404549

405550

itk/run_itk.sh

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,24 @@ RESPONSE=$(curl -s -X POST http://127.0.0.1:8000/run \
163163
"edges": ["0->1", "0->2", "1->0", "2->0"],
164164
"protocols": ["http_json"],
165165
"behavior": "push_notification"
166+
},
167+
{
168+
"name": "Resubscribe Test - JSONRPC",
169+
"sdks": ["current", "python_v10", "python_v03", "go_v10", "go_v03"],
170+
"traversal": "euler",
171+
"edges": ["0->1", "0->2", "0->3", "0->4", "1->0", "2->0", "3->0", "4->0"],
172+
"protocols": ["jsonrpc"],
173+
"streaming": true,
174+
"behavior": "resubscribe"
175+
},
176+
{
177+
"name": "Resubscribe Test - Python & Go Non-JSONRPC Protocols",
178+
"sdks": ["current", "python_v10", "python_v03", "go_v10"],
179+
"traversal": "euler",
180+
"edges": ["0->1", "0->2", "0->3", "1->0", "2->0", "3->0"],
181+
"protocols": ["grpc", "http_json"],
182+
"streaming": true,
183+
"behavior": "resubscribe"
166184
}
167185
]
168186
}')

0 commit comments

Comments
 (0)