forked from a2aproject/a2a-python
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcli.py
More file actions
122 lines (97 loc) · 3.73 KB
/
cli.py
File metadata and controls
122 lines (97 loc) · 3.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import asyncio
import os
import signal
import uuid
from typing import Any
import grpc
import httpx
from a2a.client import A2ACardResolver, ClientConfig, create_client
from a2a.types import Message, Part, Role, SendMessageRequest, TaskState
from a2a.utils.message import get_message_text
# Names of task states after which an A2A server accepts no further messages
# for the task. Matched against ``TaskState.Name(...)`` rather than enum
# members; "CANCELED" is listed under both spellings to be safe against
# either naming in the installed ``a2a.types`` (an unused name is harmless).
_TERMINAL_STATE_NAMES = frozenset(
    {
        'TASK_STATE_COMPLETED',
        'TASK_STATE_FAILED',
        'TASK_STATE_CANCELED',
        'TASK_STATE_CANCELLED',
        'TASK_STATE_REJECTED',
    }
)


def _event_task_id(event: Any) -> str:
    """Return the task id carried by a stream event, or ``''`` if none.

    Task events expose it as ``event.task.id``; status and artifact updates
    carry it in their own ``task_id`` field (assumes the protobuf
    ``TaskStatusUpdateEvent`` / ``TaskArtifactUpdateEvent`` messages —
    confirm against the installed ``a2a`` version).
    """
    if event.HasField('status_update'):
        return event.status_update.task_id
    if event.HasField('artifact_update'):
        return event.artifact_update.task_id
    return event.task.id


async def _handle_stream(
    stream: Any, current_task_id: str | None
) -> str | None:
    """Print every event from *stream* and return the task id to continue.

    Args:
        stream: Async iterator of events returned by ``client.send_message``.
            Events are expected to support ``HasField`` and expose
            ``task`` / ``status_update`` / ``artifact_update`` fields
            (StreamResponse-style) — assumption, confirm against ``a2a``.
        current_task_id: Id of the task the outgoing message was attached
            to, or ``None`` if the message started a new task.

    Returns:
        The task id the next user message should be attached to. ``None``
        once the task reports a terminal state (completed, failed, canceled,
        rejected), so the next message starts a fresh task.
    """
    task_finished = False
    async for event in stream:
        # Check the event before reading any of its fields.
        if not event:
            continue
        # Adopt the id from the first event that carries one, but never
        # re-adopt a task this stream already reported as finished.
        if not current_task_id and not task_finished:
            current_task_id = _event_task_id(event) or None

        if event.HasField('status_update'):
            status = event.status_update.status
            state_name = TaskState.Name(status.state)
            print(f'TaskStatusUpdate [state={state_name}]:', end=' ')
            if status.HasField('message'):
                print(get_message_text(status.message, delimiter=' '))
            print()
            if state_name in _TERMINAL_STATE_NAMES:
                # A terminal task cannot take follow-up messages; forget it
                # so the next user turn starts a new task.
                current_task_id = None
                task_finished = True
                if state_name == 'TASK_STATE_COMPLETED':
                    print('--- Task Completed ---')
                else:
                    print(f'--- Task Ended ({state_name}) ---')
        elif event.HasField('artifact_update'):
            artifact = event.artifact_update.artifact
            print(f'TaskArtifactUpdate [name={artifact.name}]:', end=' ')
            for part in artifact.parts:
                if part.text:
                    print(part.text, end=' ')
            print()
    return current_task_id
def _parse_args() -> argparse.Namespace:
    """Parse the ``--url`` and ``--transport`` command-line options."""
    parser = argparse.ArgumentParser(description='A2A Terminal Client')
    parser.add_argument(
        '--url', default='http://127.0.0.1:41241', help='Agent base URL'
    )
    parser.add_argument(
        '--transport',
        default=None,
        help='Preferred transport (JSONRPC, HTTP+JSON, GRPC)',
    )
    return parser.parse_args()


async def _chat_loop(client: Any) -> None:
    """Read lines from stdin and stream each one to *client* until quit.

    The loop ends on ``/quit`` or ``/exit`` (case-insensitive, surrounding
    whitespace ignored), on EOF (Ctrl-D), or on ``KeyboardInterrupt`` at the
    prompt. Blank lines are ignored.
    """
    loop = asyncio.get_running_loop()
    current_task_id: str | None = None
    # One context id is reused for every message sent in this session.
    current_context_id = str(uuid.uuid4())
    while True:
        try:
            # input() blocks, so run it in the default executor to keep the
            # event loop free.
            user_input = await loop.run_in_executor(None, input, 'You: ')
        except (EOFError, KeyboardInterrupt):
            # EOF (Ctrl-D) used to escape as an uncaught traceback.
            break
        command = user_input.strip()
        if command.lower() in ('/quit', '/exit'):
            break
        if not command:
            continue
        message = Message(
            role=Role.ROLE_USER,
            message_id=str(uuid.uuid4()),
            parts=[Part(text=user_input)],
            task_id=current_task_id,
            context_id=current_context_id,
        )
        request = SendMessageRequest(message=message)
        try:
            stream = client.send_message(request)
            current_task_id = await _handle_stream(stream, current_task_id)
        except (httpx.HTTPError, grpc.RpcError) as e:
            # httpx.HTTPError covers both RequestError (transport failures)
            # and HTTPStatusError (non-2xx responses).
            print(f'Error communicating with agent: {e}')


async def main() -> None:
    """Run the A2A terminal client.

    Fetches the agent card from ``--url``, creates a client (restricted to
    ``--transport`` when given), then runs an interactive chat loop. The
    client is closed even if the loop raises.
    """
    args = _parse_args()
    config = ClientConfig()
    if args.transport:
        config.supported_protocol_bindings = [args.transport]
    print(
        f'Connecting to {args.url} (preferred transport: {args.transport or "Any"})'
    )
    async with httpx.AsyncClient() as httpx_client:
        resolver = A2ACardResolver(httpx_client, args.url)
        card = await resolver.get_agent_card()
        print('\n✓ Agent Card Found:')
        print(f' Name: {card.name}')
        client = await create_client(card, client_config=config)
        try:
            # Peeks at a private attribute purely for display; falls back to
            # the client itself when it is absent.
            actual_transport = getattr(client, '_transport', client)
            print(f' Picked Transport: {actual_transport.__class__.__name__}')
            print('\nConnected! Send a message or type /quit to exit.')
            await _chat_loop(client)
        finally:
            await client.close()
if __name__ == '__main__':
    # Likely rationale: input() runs in an executor thread that Ctrl-C cannot
    # interrupt, and asyncio.run() waits for that executor on shutdown, so the
    # handler terminates the process at once. NOTE: this skips all cleanup
    # (client.close(), the httpx client) and exits with status 0.
    def _exit_now(_signum: int, _frame: object) -> None:
        os._exit(0)

    signal.signal(signal.SIGINT, _exit_now)
    asyncio.run(main())