Skip to content

Commit 7f21610

Browse files
committed
Merge remote-tracking branch 'origin/1.0-dev' into ishymko/regen-pb
2 parents 1426720 + 942f4ae commit 7f21610

24 files changed

Lines changed: 503 additions & 964 deletions

src/a2a/client/base_client.py

Lines changed: 39 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
from collections.abc import AsyncGenerator, AsyncIterator, Callable
2-
from typing import Any
32

43
from a2a.client.client import (
54
Client,
6-
ClientCallContext,
75
ClientConfig,
86
ClientEvent,
97
Consumer,
108
)
119
from a2a.client.client_task_manager import ClientTaskManager
12-
from a2a.client.middleware import ClientCallInterceptor
10+
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1311
from a2a.client.transports.base import ClientTransport
1412
from a2a.types.a2a_pb2 import (
1513
AgentCard,
@@ -22,8 +20,6 @@
2220
ListTaskPushNotificationConfigsResponse,
2321
ListTasksRequest,
2422
ListTasksResponse,
25-
Message,
26-
SendMessageConfiguration,
2723
SendMessageRequest,
2824
StreamResponse,
2925
SubscribeToTaskRequest,
@@ -50,12 +46,9 @@ def __init__(
5046

5147
async def send_message(
5248
self,
53-
request: Message,
49+
request: SendMessageRequest,
5450
*,
55-
configuration: SendMessageConfiguration | None = None,
5651
context: ClientCallContext | None = None,
57-
request_metadata: dict[str, Any] | None = None,
58-
extensions: list[str] | None = None,
5952
) -> AsyncIterator[ClientEvent]:
6053
"""Sends a message to the agent.
6154
@@ -65,35 +58,15 @@ async def send_message(
6558
6659
Args:
6760
request: The message to send to the agent.
68-
configuration: Optional per-call overrides for message sending behavior.
69-
context: The client call context.
70-
request_metadata: Extensions Metadata attached to the request.
71-
extensions: List of extensions to be activated.
61+
context: Optional client call context.
7262
7363
Yields:
7464
An async iterator of `ClientEvent`
7565
"""
76-
config = SendMessageConfiguration(
77-
accepted_output_modes=self._config.accepted_output_modes,
78-
blocking=not self._config.polling,
79-
task_push_notification_config=(
80-
self._config.push_notification_configs[0]
81-
if self._config.push_notification_configs
82-
else None
83-
),
84-
)
85-
86-
if configuration:
87-
config.MergeFrom(configuration)
88-
config.blocking = configuration.blocking
89-
90-
send_message_request = SendMessageRequest(
91-
message=request, configuration=config, metadata=request_metadata
92-
)
93-
66+
self._apply_client_config(request)
9467
if not self._config.streaming or not self._card.capabilities.streaming:
9568
response = await self._transport.send_message(
96-
send_message_request, context=context, extensions=extensions
69+
request, context=context
9770
)
9871

9972
# In non-streaming case we convert to a StreamResponse so that the
@@ -115,11 +88,29 @@ async def send_message(
11588
return
11689

11790
stream = self._transport.send_message_streaming(
118-
send_message_request, context=context, extensions=extensions
91+
request, context=context
11992
)
12093
async for client_event in self._process_stream(stream):
12194
yield client_event
12295

96+
def _apply_client_config(self, request: SendMessageRequest) -> None:
97+
if not request.configuration.blocking and self._config.polling:
98+
request.configuration.blocking = not self._config.polling
99+
if (
100+
not request.configuration.HasField('task_push_notification_config')
101+
and self._config.push_notification_configs
102+
):
103+
request.configuration.task_push_notification_config.CopyFrom(
104+
self._config.push_notification_configs[0]
105+
)
106+
if (
107+
not request.configuration.accepted_output_modes
108+
and self._config.accepted_output_modes
109+
):
110+
request.configuration.accepted_output_modes.extend(
111+
self._config.accepted_output_modes
112+
)
113+
123114
async def _process_stream(
124115
self, stream: AsyncIterator[StreamResponse]
125116
) -> AsyncGenerator[ClientEvent]:
@@ -146,21 +137,17 @@ async def get_task(
146137
request: GetTaskRequest,
147138
*,
148139
context: ClientCallContext | None = None,
149-
extensions: list[str] | None = None,
150140
) -> Task:
151141
"""Retrieves the current state and history of a specific task.
152142
153143
Args:
154144
request: The `GetTaskRequest` object specifying the task ID.
155-
context: The client call context.
156-
extensions: List of extensions to be activated.
145+
context: Optional client call context.
157146
158147
Returns:
159148
A `Task` object representing the current state of the task.
160149
"""
161-
return await self._transport.get_task(
162-
request, context=context, extensions=extensions
163-
)
150+
return await self._transport.get_task(request, context=context)
164151

165152
async def list_tasks(
166153
self,
@@ -176,118 +163,104 @@ async def cancel_task(
176163
request: CancelTaskRequest,
177164
*,
178165
context: ClientCallContext | None = None,
179-
extensions: list[str] | None = None,
180166
) -> Task:
181167
"""Requests the agent to cancel a specific task.
182168
183169
Args:
184170
request: The `CancelTaskRequest` object specifying the task ID.
185-
context: The client call context.
186-
extensions: List of extensions to be activated.
171+
context: Optional client call context.
187172
188173
Returns:
189174
A `Task` object containing the updated task status.
190175
"""
191-
return await self._transport.cancel_task(
192-
request, context=context, extensions=extensions
193-
)
176+
return await self._transport.cancel_task(request, context=context)
194177

195178
async def create_task_push_notification_config(
196179
self,
197180
request: TaskPushNotificationConfig,
198181
*,
199182
context: ClientCallContext | None = None,
200-
extensions: list[str] | None = None,
201183
) -> TaskPushNotificationConfig:
202184
"""Sets or updates the push notification configuration for a specific task.
203185
204186
Args:
205187
request: The `TaskPushNotificationConfig` object with the new configuration.
206-
context: The client call context.
207-
extensions: List of extensions to be activated.
188+
context: Optional client call context.
208189
209190
Returns:
210191
The created or updated `TaskPushNotificationConfig` object.
211192
"""
212193
return await self._transport.create_task_push_notification_config(
213-
request, context=context, extensions=extensions
194+
request, context=context
214195
)
215196

216197
async def get_task_push_notification_config(
217198
self,
218199
request: GetTaskPushNotificationConfigRequest,
219200
*,
220201
context: ClientCallContext | None = None,
221-
extensions: list[str] | None = None,
222202
) -> TaskPushNotificationConfig:
223203
"""Retrieves the push notification configuration for a specific task.
224204
225205
Args:
226206
request: The `GetTaskPushNotificationConfigParams` object specifying the task.
227-
context: The client call context.
228-
extensions: List of extensions to be activated.
207+
context: Optional client call context.
229208
230209
Returns:
231210
A `TaskPushNotificationConfig` object containing the configuration.
232211
"""
233212
return await self._transport.get_task_push_notification_config(
234-
request, context=context, extensions=extensions
213+
request, context=context
235214
)
236215

237216
async def list_task_push_notification_configs(
238217
self,
239218
request: ListTaskPushNotificationConfigsRequest,
240219
*,
241220
context: ClientCallContext | None = None,
242-
extensions: list[str] | None = None,
243221
) -> ListTaskPushNotificationConfigsResponse:
244222
"""Lists push notification configurations for a specific task.
245223
246224
Args:
247225
request: The `ListTaskPushNotificationConfigsRequest` object specifying the request.
248-
context: The client call context.
249-
extensions: List of extensions to be activated.
226+
context: Optional client call context.
250227
251228
Returns:
252229
A `ListTaskPushNotificationConfigsResponse` object.
253230
"""
254231
return await self._transport.list_task_push_notification_configs(
255-
request, context=context, extensions=extensions
232+
request, context=context
256233
)
257234

258235
async def delete_task_push_notification_config(
259236
self,
260237
request: DeleteTaskPushNotificationConfigRequest,
261238
*,
262239
context: ClientCallContext | None = None,
263-
extensions: list[str] | None = None,
264240
) -> None:
265241
"""Deletes the push notification configuration for a specific task.
266242
267243
Args:
268244
request: The `DeleteTaskPushNotificationConfigRequest` object specifying the request.
269-
context: The client call context.
270-
extensions: List of extensions to be activated.
245+
context: Optional client call context.
271246
"""
272247
await self._transport.delete_task_push_notification_config(
273-
request, context=context, extensions=extensions
248+
request, context=context
274249
)
275250

276251
async def subscribe(
277252
self,
278253
request: SubscribeToTaskRequest,
279254
*,
280255
context: ClientCallContext | None = None,
281-
extensions: list[str] | None = None,
282256
) -> AsyncIterator[ClientEvent]:
283257
"""Resubscribes to a task's event stream.
284258
285259
This is only available if both the client and server support streaming.
286260
287261
Args:
288262
request: Parameters to identify the task to resubscribe to.
289-
context: The client call context.
290-
extensions: List of extensions to be activated.
263+
context: Optional client call context.
291264
292265
Yields:
293266
An async iterator of `ClientEvent` objects.
@@ -303,9 +276,7 @@ async def subscribe(
303276
# Note: resubscribe can only be called on an existing task. As such,
304277
# we should never see Message updates, despite the typing of the service
305278
# definition indicating it may be possible.
306-
stream = self._transport.subscribe(
307-
request, context=context, extensions=extensions
308-
)
279+
stream = self._transport.subscribe(request, context=context)
309280
async for client_event in self._process_stream(stream):
310281
yield client_event
311282

@@ -314,7 +285,6 @@ async def get_extended_agent_card(
314285
request: GetExtendedAgentCardRequest,
315286
*,
316287
context: ClientCallContext | None = None,
317-
extensions: list[str] | None = None,
318288
signature_verifier: Callable[[AgentCard], None] | None = None,
319289
) -> AgentCard:
320290
"""Retrieves the agent's card.
@@ -324,8 +294,7 @@ async def get_extended_agent_card(
324294
325295
Args:
326296
request: The `GetExtendedAgentCardRequest` object specifying the request.
327-
context: The client call context.
328-
extensions: List of extensions to be activated.
297+
context: Optional client call context.
329298
signature_verifier: A callable used to verify the agent card's signatures.
330299
331300
Returns:
@@ -334,7 +303,6 @@ async def get_extended_agent_card(
334303
card = await self._transport.get_extended_agent_card(
335304
request,
336305
context=context,
337-
extensions=extensions,
338306
signature_verifier=signature_verifier,
339307
)
340308
self._card = card

src/a2a/client/client.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
ListTaskPushNotificationConfigsResponse,
2424
ListTasksRequest,
2525
ListTasksResponse,
26-
Message,
27-
SendMessageConfiguration,
26+
SendMessageRequest,
2827
StreamResponse,
2928
SubscribeToTaskRequest,
3029
Task,
@@ -75,9 +74,6 @@ class ClientConfig:
7574
)
7675
"""Push notification configurations to use for every request."""
7776

78-
extensions: list[str] = dataclasses.field(default_factory=list)
79-
"""A list of extension URIs the client supports."""
80-
8177

8278
ClientEvent = tuple[StreamResponse, Task | None]
8379

@@ -128,12 +124,9 @@ async def __aexit__(
128124
@abstractmethod
129125
async def send_message(
130126
self,
131-
request: Message,
127+
request: SendMessageRequest,
132128
*,
133-
configuration: SendMessageConfiguration | None = None,
134129
context: ClientCallContext | None = None,
135-
request_metadata: dict[str, Any] | None = None,
136-
extensions: list[str] | None = None,
137130
) -> AsyncIterator[ClientEvent]:
138131
"""Sends a message to the server.
139132
@@ -152,7 +145,6 @@ async def get_task(
152145
request: GetTaskRequest,
153146
*,
154147
context: ClientCallContext | None = None,
155-
extensions: list[str] | None = None,
156148
) -> Task:
157149
"""Retrieves the current state and history of a specific task."""
158150

@@ -171,7 +163,6 @@ async def cancel_task(
171163
request: CancelTaskRequest,
172164
*,
173165
context: ClientCallContext | None = None,
174-
extensions: list[str] | None = None,
175166
) -> Task:
176167
"""Requests the agent to cancel a specific task."""
177168

@@ -181,7 +172,6 @@ async def create_task_push_notification_config(
181172
request: TaskPushNotificationConfig,
182173
*,
183174
context: ClientCallContext | None = None,
184-
extensions: list[str] | None = None,
185175
) -> TaskPushNotificationConfig:
186176
"""Sets or updates the push notification configuration for a specific task."""
187177

@@ -191,7 +181,6 @@ async def get_task_push_notification_config(
191181
request: GetTaskPushNotificationConfigRequest,
192182
*,
193183
context: ClientCallContext | None = None,
194-
extensions: list[str] | None = None,
195184
) -> TaskPushNotificationConfig:
196185
"""Retrieves the push notification configuration for a specific task."""
197186

@@ -201,7 +190,6 @@ async def list_task_push_notification_configs(
201190
request: ListTaskPushNotificationConfigsRequest,
202191
*,
203192
context: ClientCallContext | None = None,
204-
extensions: list[str] | None = None,
205193
) -> ListTaskPushNotificationConfigsResponse:
206194
"""Lists push notification configurations for a specific task."""
207195

@@ -211,7 +199,6 @@ async def delete_task_push_notification_config(
211199
request: DeleteTaskPushNotificationConfigRequest,
212200
*,
213201
context: ClientCallContext | None = None,
214-
extensions: list[str] | None = None,
215202
) -> None:
216203
"""Deletes the push notification configuration for a specific task."""
217204

@@ -221,7 +208,6 @@ async def subscribe(
221208
request: SubscribeToTaskRequest,
222209
*,
223210
context: ClientCallContext | None = None,
224-
extensions: list[str] | None = None,
225211
) -> AsyncIterator[ClientEvent]:
226212
"""Resubscribes to a task's event stream."""
227213
return
@@ -233,7 +219,6 @@ async def get_extended_agent_card(
233219
request: GetExtendedAgentCardRequest,
234220
*,
235221
context: ClientCallContext | None = None,
236-
extensions: list[str] | None = None,
237222
signature_verifier: Callable[[AgentCard], None] | None = None,
238223
) -> AgentCard:
239224
"""Retrieves the agent's card."""

0 commit comments

Comments
 (0)