|
1 | 1 | import uuid |
2 | 2 |
|
| 3 | +from types import TracebackType |
| 4 | + |
| 5 | +from typing_extensions import Self |
| 6 | + |
3 | 7 | from a2a.client.client import Client, ClientCallContext |
4 | | -from a2a.types import Message, Part, Role, SendMessageRequest |
| 8 | +from a2a.types import Message, Part, Role, SendMessageRequest, TaskState |
| 9 | +from a2a.utils import get_artifact_text, get_message_text |
| 10 | + |
| 11 | + |
| 12 | +_TERMINAL_STATES: frozenset[TaskState] = frozenset( |
| 13 | + { |
| 14 | + TaskState.TASK_STATE_COMPLETED, |
| 15 | + TaskState.TASK_STATE_FAILED, |
| 16 | + TaskState.TASK_STATE_CANCELED, |
| 17 | + TaskState.TASK_STATE_REJECTED, |
| 18 | + } |
| 19 | +) |
5 | 20 |
|
6 | 21 |
|
7 | 22 | class TextClient: |
8 | 23 | """A facade around Client that simplifies text-based communication. |
9 | 24 |
|
10 | 25 | Wraps an underlying Client instance and exposes a simplified interface |
11 | 26 | for sending plain-text messages and receiving aggregated text responses. |
| 27 | + Maintains session state (context_id, task_id) automatically across calls. |
12 | 28 | For full Client API access, use the underlying client directly via |
13 | 29 | the `client` property. |
14 | 30 | """ |
15 | 31 |
|
16 | 32 | def __init__(self, client: Client): |
17 | 33 | self._client = client |
| 34 | + self._context_id: str = str(uuid.uuid4()) |
| 35 | + self._task_id: str | None = None |
| 36 | + |
| 37 | + async def __aenter__(self) -> Self: |
| 38 | + """Enters the async context manager.""" |
| 39 | + return self |
| 40 | + |
| 41 | + async def __aexit__( |
| 42 | + self, |
| 43 | + exc_type: type[BaseException] | None, |
| 44 | + exc_val: BaseException | None, |
| 45 | + exc_tb: TracebackType | None, |
| 46 | + ) -> None: |
| 47 | + """Exits the async context manager and closes the client.""" |
| 48 | + await self.close() |
18 | 49 |
|
19 | 50 | @property |
20 | 51 | def client(self) -> Client: |
21 | 52 | """Returns the underlying Client instance for full API access.""" |
22 | 53 | return self._client |
23 | 54 |
|
| 55 | + def reset_session(self) -> None: |
| 56 | + """Starts a new session by generating a fresh context ID and clearing the task ID.""" |
| 57 | + self._context_id = str(uuid.uuid4()) |
| 58 | + self._task_id = None |
| 59 | + |
24 | 60 | async def send_text_message( |
25 | 61 | self, |
26 | 62 | text: str, |
27 | 63 | *, |
| 64 | + delimiter: str = ' ', |
28 | 65 | context: ClientCallContext | None = None, |
29 | 66 | ) -> str: |
30 | | - """Sends a text message and returns the aggregated text response.""" |
| 67 | + """Sends a text message and returns the aggregated text response. |
| 68 | +
|
| 69 | + Session state (context_id, task_id) is managed automatically across |
| 70 | + calls. Use reset_session() to start a new conversation. |
| 71 | +
|
| 72 | + Args: |
| 73 | + text: The plain-text message to send. |
| 74 | + delimiter: String used to join response parts. Defaults to a |
| 75 | + single space. Use '' for token-streamed responses or '\\n' |
| 76 | + for paragraph-separated chunks. |
| 77 | + context: Optional call-level context. |
| 78 | + """ |
31 | 79 | request = SendMessageRequest( |
32 | 80 | message=Message( |
33 | 81 | role=Role.ROLE_USER, |
34 | 82 | message_id=str(uuid.uuid4()), |
| 83 | + context_id=self._context_id, |
| 84 | + task_id=self._task_id, |
35 | 85 | parts=[Part(text=text)], |
36 | 86 | ) |
37 | 87 | ) |
38 | 88 |
|
39 | 89 | response_parts: list[str] = [] |
40 | 90 |
|
41 | 91 | async for event in self._client.send_message(request, context=context): |
42 | | - if event.HasField('message'): |
43 | | - response_parts.extend( |
44 | | - part.text for part in event.message.parts if part.text |
45 | | - ) |
| 92 | + if event.HasField('task'): |
| 93 | + self._task_id = event.task.id |
| 94 | + elif event.HasField('message'): |
| 95 | + response_parts.append(get_message_text(event.message)) |
46 | 96 | elif event.HasField('status_update'): |
| 97 | + if event.status_update.task_id: |
| 98 | + self._task_id = event.status_update.task_id |
| 99 | + if event.status_update.status.state in _TERMINAL_STATES: |
| 100 | + self._task_id = None |
47 | 101 | if event.status_update.status.HasField('message'): |
48 | | - response_parts.extend( |
49 | | - part.text |
50 | | - for part in event.status_update.status.message.parts |
51 | | - if part.text |
| 102 | + response_parts.append( |
| 103 | + get_message_text(event.status_update.status.message) |
52 | 104 | ) |
53 | 105 | elif event.HasField('artifact_update'): |
54 | | - response_parts.extend( |
55 | | - part.text |
56 | | - for part in event.artifact_update.artifact.parts |
57 | | - if part.text |
| 106 | + response_parts.append( |
| 107 | + get_artifact_text(event.artifact_update.artifact) |
58 | 108 | ) |
59 | 109 |
|
60 | | - return ' '.join(response_parts) |
| 110 | + return delimiter.join(response_parts) |
61 | 111 |
|
62 | 112 | async def close(self) -> None: |
63 | 113 | """Closes the underlying client.""" |
|
0 commit comments