|
11 | 11 |
|
12 | 12 | from a2a.client import A2ACardResolver, ClientConfig, create_client |
13 | 13 | from a2a.types import Message, Part, Role, SendMessageRequest, TaskState |
14 | | -from a2a.utils.message import get_message_text |
| 14 | +from a2a.utils import get_artifact_text, get_message_text |
| 15 | +from a2a.utils.agent_card import display_agent_card |
15 | 16 |
|
16 | 17 |
|
17 | 18 | async def _handle_stream( |
18 | 19 | stream: Any, current_task_id: str | None |
19 | 20 | ) -> str | None: |
20 | 21 | async for event in stream: |
| 22 | + if event.HasField('message'): |
| 23 | + print('Message:', get_message_text(event.message, delimiter=' ')) |
| 24 | + return None |
| 25 | + |
21 | 26 | if not current_task_id: |
22 | | - current_task_id = event.task.id |
23 | | - if event: |
24 | | - if event.HasField('status_update'): |
25 | | - state_name = TaskState.Name(event.status_update.status.state) |
26 | | - print(f'TaskStatusUpdate [state={state_name}]:', end=' ') |
27 | | - if event.status_update.status.HasField('message'): |
28 | | - message = event.status_update.status.message |
29 | | - print(get_message_text(message, delimiter=' ')) |
30 | | - print() |
31 | | - |
32 | | - if ( |
33 | | - event.status_update.status.state |
34 | | - == TaskState.TASK_STATE_COMPLETED |
35 | | - ): |
36 | | - current_task_id = None |
37 | | - print('--- Task Completed ---') |
38 | | - |
39 | | - elif event.HasField('artifact_update'): |
40 | | - print( |
41 | | - f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', |
42 | | - end=' ', |
| 27 | + if event.HasField('task'): |
| 28 | + current_task_id = event.task.id |
| 29 | + print('--- Task Started ---') |
| 30 | + print(f'Task [state={TaskState.Name(event.task.status.state)}]') |
| 31 | + else: |
| 32 | + raise ValueError(f'Unexpected first event: {event}') |
| 33 | + |
| 34 | + if event.HasField('status_update'): |
| 35 | + state_name = TaskState.Name(event.status_update.status.state) |
| 36 | + message_text = ( |
| 37 | + ': ' |
| 38 | + + get_message_text( |
| 39 | + event.status_update.status.message, delimiter=' ' |
43 | 40 | ) |
44 | | - for part in event.artifact_update.artifact.parts: |
45 | | - if part.text: |
46 | | - print(part.text, end=' ') |
47 | | - print() |
48 | | - |
| 41 | + if event.status_update.status.HasField('message') |
| 42 | + else '' |
| 43 | + ) |
| 44 | + print(f'TaskStatusUpdate [state={state_name}]{message_text}') |
| 45 | + if state_name in ( |
| 46 | + 'TASK_STATE_COMPLETED', |
| 47 | + 'TASK_STATE_FAILED', |
| 48 | + 'TASK_STATE_CANCELED', |
| 49 | + 'TASK_STATE_REJECTED', |
| 50 | + ): |
| 51 | + current_task_id = None |
| 52 | + print('--- Task Finished ---') |
| 53 | + elif event.HasField('artifact_update'): |
| 54 | + print( |
| 55 | + f'TaskArtifactUpdate [name={event.artifact_update.artifact.name}]:', |
| 56 | + get_artifact_text( |
| 57 | + event.artifact_update.artifact, delimiter=' ' |
| 58 | + ), |
| 59 | + ) |
49 | 60 | return current_task_id |
50 | 61 |
|
51 | 62 |
|
@@ -76,7 +87,7 @@ async def main() -> None: |
76 | 87 | resolver = A2ACardResolver(httpx_client, args.url) |
77 | 88 | card = await resolver.get_agent_card() |
78 | 89 | print('\n✓ Agent Card Found:') |
79 | | - print(f' Name: {card.name}') |
| 90 | + display_agent_card(card) |
80 | 91 |
|
81 | 92 | client = await create_client(card, client_config=config) |
82 | 93 |
|
|
0 commit comments