Skip to content

Commit 94537c3

Browse files
authored
fix(client): do not mutate SendMessageRequest in BaseClient.send_message (#949)
Updating passed parameter by reference is not great.
1 parent 617fdf3 commit 94537c3

2 files changed

Lines changed: 52 additions & 10 deletions

File tree

src/a2a/client/base_client.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ async def send_message(
6666
Yields:
6767
An async iterator of `StreamResponse`
6868
"""
69-
self._apply_client_config(request)
69+
request = self._apply_client_config(request)
7070
if not self._config.streaming or not self._card.capabilities.streaming:
7171
response = await self._execute_with_interceptors(
7272
input_data=request,
@@ -100,22 +100,29 @@ async def send_message(
100100
):
101101
yield event
102102

103-
def _apply_client_config(self, request: SendMessageRequest) -> None:
104-
request.configuration.return_immediately |= self._config.polling
105-
if (
106-
not request.configuration.HasField('task_push_notification_config')
107-
and self._config.push_notification_configs
103+
def _apply_client_config(
104+
self, request: SendMessageRequest
105+
) -> SendMessageRequest:
106+
modified_request = SendMessageRequest()
107+
modified_request.CopyFrom(request)
108+
if self._config.polling:
109+
modified_request.configuration.return_immediately = True
110+
if self._config.push_notification_configs and (
111+
not modified_request.configuration.HasField(
112+
'task_push_notification_config'
113+
)
108114
):
109-
request.configuration.task_push_notification_config.CopyFrom(
115+
modified_request.configuration.task_push_notification_config.CopyFrom(
110116
self._config.push_notification_configs[0]
111117
)
112118
if (
113-
not request.configuration.accepted_output_modes
114-
and self._config.accepted_output_modes
119+
self._config.accepted_output_modes
120+
and not modified_request.configuration.accepted_output_modes
115121
):
116-
request.configuration.accepted_output_modes.extend(
122+
modified_request.configuration.accepted_output_modes.extend(
117123
self._config.accepted_output_modes
118124
)
125+
return modified_request
119126

120127
async def _process_stream(
121128
self,

tests/client/test_base_client.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,41 @@ async def test_send_message_non_streaming_agent_capability_false(
208208
response = events[0]
209209
assert response.task.id == 'task-789'
210210

211+
@pytest.mark.asyncio
212+
async def test_send_message_does_not_mutate_request(
213+
self,
214+
base_client: BaseClient,
215+
mock_transport: MagicMock,
216+
sample_message: Message,
217+
):
218+
base_client._config.streaming = False
219+
base_client._config.polling = True
220+
base_client._config.accepted_output_modes = ['application/json']
221+
base_client._config.push_notification_configs = [
222+
TaskPushNotificationConfig(
223+
task_id='task-1',
224+
)
225+
]
226+
227+
task = Task(
228+
id='task-no-mutate',
229+
context_id='ctx-no-mutate',
230+
status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED),
231+
)
232+
response = SendMessageResponse()
233+
response.task.CopyFrom(task)
234+
mock_transport.send_message.return_value = response
235+
236+
request = SendMessageRequest(message=sample_message)
237+
238+
original = SendMessageRequest()
239+
original.CopyFrom(request)
240+
241+
events = [event async for event in base_client.send_message(request)]
242+
assert len(events) == 1
243+
244+
assert request == original
245+
211246
@pytest.mark.asyncio
212247
async def test_send_message_callsite_config_overrides_non_streaming(
213248
self,

0 commit comments

Comments
 (0)