11from collections .abc import AsyncGenerator , AsyncIterator , Callable
2- from typing import Any
32
43from a2a .client .client import (
54 Client ,
6- ClientCallContext ,
75 ClientConfig ,
86 ClientEvent ,
97 Consumer ,
108)
119from a2a .client .client_task_manager import ClientTaskManager
12- from a2a .client .middleware import ClientCallInterceptor
10+ from a2a .client .middleware import ClientCallContext , ClientCallInterceptor
1311from a2a .client .transports .base import ClientTransport
1412from a2a .types .a2a_pb2 import (
1513 AgentCard ,
2220 ListTaskPushNotificationConfigsResponse ,
2321 ListTasksRequest ,
2422 ListTasksResponse ,
25- Message ,
26- SendMessageConfiguration ,
2723 SendMessageRequest ,
2824 StreamResponse ,
2925 SubscribeToTaskRequest ,
@@ -50,12 +46,9 @@ def __init__(
5046
5147 async def send_message (
5248 self ,
53- request : Message ,
49+ request : SendMessageRequest ,
5450 * ,
55- configuration : SendMessageConfiguration | None = None ,
5651 context : ClientCallContext | None = None ,
57- request_metadata : dict [str , Any ] | None = None ,
58- extensions : list [str ] | None = None ,
5952 ) -> AsyncIterator [ClientEvent ]:
6053 """Sends a message to the agent.
6154
@@ -65,35 +58,15 @@ async def send_message(
6558
6659 Args:
6760 request: The message to send to the agent.
68- configuration: Optional per-call overrides for message sending behavior.
69- context: The client call context.
70- request_metadata: Extensions Metadata attached to the request.
71- extensions: List of extensions to be activated.
61+ context: Optional client call context.
7262
7363 Yields:
7464 An async iterator of `ClientEvent`
7565 """
76- config = SendMessageConfiguration (
77- accepted_output_modes = self ._config .accepted_output_modes ,
78- blocking = not self ._config .polling ,
79- task_push_notification_config = (
80- self ._config .push_notification_configs [0 ]
81- if self ._config .push_notification_configs
82- else None
83- ),
84- )
85-
86- if configuration :
87- config .MergeFrom (configuration )
88- config .blocking = configuration .blocking
89-
90- send_message_request = SendMessageRequest (
91- message = request , configuration = config , metadata = request_metadata
92- )
93-
66+ self ._apply_client_config (request )
9467 if not self ._config .streaming or not self ._card .capabilities .streaming :
9568 response = await self ._transport .send_message (
96- send_message_request , context = context , extensions = extensions
69+ request , context = context
9770 )
9871
9972 # In non-streaming case we convert to a StreamResponse so that the
@@ -115,11 +88,29 @@ async def send_message(
11588 return
11689
11790 stream = self ._transport .send_message_streaming (
118- send_message_request , context = context , extensions = extensions
91+ request , context = context
11992 )
12093 async for client_event in self ._process_stream (stream ):
12194 yield client_event
12295
96+ def _apply_client_config (self , request : SendMessageRequest ) -> None :
97+ if not request .configuration .blocking and self ._config .polling :
98+ request .configuration .blocking = not self ._config .polling
99+ if (
100+ not request .configuration .HasField ('task_push_notification_config' )
101+ and self ._config .push_notification_configs
102+ ):
103+ request .configuration .task_push_notification_config .CopyFrom (
104+ self ._config .push_notification_configs [0 ]
105+ )
106+ if (
107+ not request .configuration .accepted_output_modes
108+ and self ._config .accepted_output_modes
109+ ):
110+ request .configuration .accepted_output_modes .extend (
111+ self ._config .accepted_output_modes
112+ )
113+
123114 async def _process_stream (
124115 self , stream : AsyncIterator [StreamResponse ]
125116 ) -> AsyncGenerator [ClientEvent ]:
@@ -146,21 +137,17 @@ async def get_task(
146137 request : GetTaskRequest ,
147138 * ,
148139 context : ClientCallContext | None = None ,
149- extensions : list [str ] | None = None ,
150140 ) -> Task :
151141 """Retrieves the current state and history of a specific task.
152142
153143 Args:
154144 request: The `GetTaskRequest` object specifying the task ID.
155- context: The client call context.
156- extensions: List of extensions to be activated.
145+ context: Optional client call context.
157146
158147 Returns:
159148 A `Task` object representing the current state of the task.
160149 """
161- return await self ._transport .get_task (
162- request , context = context , extensions = extensions
163- )
150+ return await self ._transport .get_task (request , context = context )
164151
165152 async def list_tasks (
166153 self ,
@@ -176,118 +163,104 @@ async def cancel_task(
176163 request : CancelTaskRequest ,
177164 * ,
178165 context : ClientCallContext | None = None ,
179- extensions : list [str ] | None = None ,
180166 ) -> Task :
181167 """Requests the agent to cancel a specific task.
182168
183169 Args:
184170 request: The `CancelTaskRequest` object specifying the task ID.
185- context: The client call context.
186- extensions: List of extensions to be activated.
171+ context: Optional client call context.
187172
188173 Returns:
189174 A `Task` object containing the updated task status.
190175 """
191- return await self ._transport .cancel_task (
192- request , context = context , extensions = extensions
193- )
176+ return await self ._transport .cancel_task (request , context = context )
194177
195178 async def create_task_push_notification_config (
196179 self ,
197180 request : TaskPushNotificationConfig ,
198181 * ,
199182 context : ClientCallContext | None = None ,
200- extensions : list [str ] | None = None ,
201183 ) -> TaskPushNotificationConfig :
202184 """Sets or updates the push notification configuration for a specific task.
203185
204186 Args:
205187 request: The `TaskPushNotificationConfig` object with the new configuration.
206- context: The client call context.
207- extensions: List of extensions to be activated.
188+ context: Optional client call context.
208189
209190 Returns:
210191 The created or updated `TaskPushNotificationConfig` object.
211192 """
212193 return await self ._transport .create_task_push_notification_config (
213- request , context = context , extensions = extensions
194+ request , context = context
214195 )
215196
216197 async def get_task_push_notification_config (
217198 self ,
218199 request : GetTaskPushNotificationConfigRequest ,
219200 * ,
220201 context : ClientCallContext | None = None ,
221- extensions : list [str ] | None = None ,
222202 ) -> TaskPushNotificationConfig :
223203 """Retrieves the push notification configuration for a specific task.
224204
225205 Args:
226206 request: The `GetTaskPushNotificationConfigParams` object specifying the task.
227- context: The client call context.
228- extensions: List of extensions to be activated.
207+ context: Optional client call context.
229208
230209 Returns:
231210 A `TaskPushNotificationConfig` object containing the configuration.
232211 """
233212 return await self ._transport .get_task_push_notification_config (
234- request , context = context , extensions = extensions
213+ request , context = context
235214 )
236215
237216 async def list_task_push_notification_configs (
238217 self ,
239218 request : ListTaskPushNotificationConfigsRequest ,
240219 * ,
241220 context : ClientCallContext | None = None ,
242- extensions : list [str ] | None = None ,
243221 ) -> ListTaskPushNotificationConfigsResponse :
244222 """Lists push notification configurations for a specific task.
245223
246224 Args:
247225 request: The `ListTaskPushNotificationConfigsRequest` object specifying the request.
248- context: The client call context.
249- extensions: List of extensions to be activated.
226+ context: Optional client call context.
250227
251228 Returns:
252229 A `ListTaskPushNotificationConfigsResponse` object.
253230 """
254231 return await self ._transport .list_task_push_notification_configs (
255- request , context = context , extensions = extensions
232+ request , context = context
256233 )
257234
258235 async def delete_task_push_notification_config (
259236 self ,
260237 request : DeleteTaskPushNotificationConfigRequest ,
261238 * ,
262239 context : ClientCallContext | None = None ,
263- extensions : list [str ] | None = None ,
264240 ) -> None :
265241 """Deletes the push notification configuration for a specific task.
266242
267243 Args:
268244 request: The `DeleteTaskPushNotificationConfigRequest` object specifying the request.
269- context: The client call context.
270- extensions: List of extensions to be activated.
245+ context: Optional client call context.
271246 """
272247 await self ._transport .delete_task_push_notification_config (
273- request , context = context , extensions = extensions
248+ request , context = context
274249 )
275250
276251 async def subscribe (
277252 self ,
278253 request : SubscribeToTaskRequest ,
279254 * ,
280255 context : ClientCallContext | None = None ,
281- extensions : list [str ] | None = None ,
282256 ) -> AsyncIterator [ClientEvent ]:
283257 """Resubscribes to a task's event stream.
284258
285259 This is only available if both the client and server support streaming.
286260
287261 Args:
288262 request: Parameters to identify the task to resubscribe to.
289- context: The client call context.
290- extensions: List of extensions to be activated.
263+ context: Optional client call context.
291264
292265 Yields:
293266 An async iterator of `ClientEvent` objects.
@@ -303,9 +276,7 @@ async def subscribe(
303276 # Note: resubscribe can only be called on an existing task. As such,
304277 # we should never see Message updates, despite the typing of the service
305278 # definition indicating it may be possible.
306- stream = self ._transport .subscribe (
307- request , context = context , extensions = extensions
308- )
279+ stream = self ._transport .subscribe (request , context = context )
309280 async for client_event in self ._process_stream (stream ):
310281 yield client_event
311282
@@ -314,7 +285,6 @@ async def get_extended_agent_card(
314285 request : GetExtendedAgentCardRequest ,
315286 * ,
316287 context : ClientCallContext | None = None ,
317- extensions : list [str ] | None = None ,
318288 signature_verifier : Callable [[AgentCard ], None ] | None = None ,
319289 ) -> AgentCard :
320290 """Retrieves the agent's card.
@@ -324,8 +294,7 @@ async def get_extended_agent_card(
324294
325295 Args:
326296 request: The `GetExtendedAgentCardRequest` object specifying the request.
327- context: The client call context.
328- extensions: List of extensions to be activated.
297+ context: Optional client call context.
329298 signature_verifier: A callable used to verify the agent card's signatures.
330299
331300 Returns:
@@ -334,7 +303,6 @@ async def get_extended_agent_card(
334303 card = await self ._transport .get_extended_agent_card (
335304 request ,
336305 context = context ,
337- extensions = extensions ,
338306 signature_verifier = signature_verifier ,
339307 )
340308 self ._card = card
0 commit comments