Skip to content

Commit bf84e2c

Browse files
wuliang229 and copybara-github
authored and committed
feat(live): Add live_session_id to LlmResponse
This change introduces a `live_session_id` field to the `LlmResponse` dataclass and populates it in all responses generated by the `GeminiLlmConnection` when using a Live session. This allows tracking which Live session each response belongs to.

Co-authored-by: Liang Wu <wuliang@google.com>
PiperOrigin-RevId: 900354397
1 parent 5c6f6fe commit bf84e2c

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

src/google/adk/models/gemini_llm_connection.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ def __build_full_text_response(self, text: str):
176176
role='model',
177177
parts=[types.Part.from_text(text=text)],
178178
),
179+
live_session_id=self._gemini_session.session_id,
179180
)
180181

181182
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
@@ -192,11 +193,13 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
192193
# partial content and emit responses as needed.
193194
async for message in agen:
194195
logger.debug('Got LLM Live message: %s', message)
196+
live_session_id = self._gemini_session.session_id
195197
if message.usage_metadata:
196198
# Tracks token usage data per model.
197199
yield LlmResponse(
198200
usage_metadata=message.usage_metadata,
199201
model_version=self._model_version,
202+
live_session_id=live_session_id,
200203
)
201204
if message.server_content:
202205
content = message.server_content.model_turn
@@ -211,13 +214,15 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
211214
grounding_metadata=message.server_content.grounding_metadata,
212215
interrupted=message.server_content.interrupted,
213216
model_version=self._model_version,
217+
live_session_id=live_session_id,
214218
)
215219

216220
if content and content.parts:
217221
llm_response = LlmResponse(
218222
content=content,
219223
interrupted=message.server_content.interrupted,
220224
model_version=self._model_version,
225+
live_session_id=live_session_id,
221226
)
222227
# grounding_metadata is yielded again at turn_complete,
223228
# so avoid duplicating it here if turn_complete is true.
@@ -248,6 +253,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
248253
),
249254
partial=True,
250255
model_version=self._model_version,
256+
live_session_id=live_session_id,
251257
)
252258
# finished=True and partial transcription may happen in the same
253259
# message.
@@ -259,6 +265,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
259265
),
260266
partial=False,
261267
model_version=self._model_version,
268+
live_session_id=live_session_id,
262269
)
263270
self._input_transcription_text = ''
264271
if message.server_content.output_transcription:
@@ -273,6 +280,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
273280
),
274281
partial=True,
275282
model_version=self._model_version,
283+
live_session_id=live_session_id,
276284
)
277285
if message.server_content.output_transcription.finished:
278286
yield LlmResponse(
@@ -282,6 +290,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
282290
),
283291
partial=False,
284292
model_version=self._model_version,
293+
live_session_id=live_session_id,
285294
)
286295
self._output_transcription_text = ''
287296
# The Gemini API might not send a transcription finished signal.
@@ -300,6 +309,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
300309
),
301310
partial=False,
302311
model_version=self._model_version,
312+
live_session_id=live_session_id,
303313
)
304314
self._input_transcription_text = ''
305315
if self._output_transcription_text:
@@ -310,6 +320,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
310320
),
311321
partial=False,
312322
model_version=self._model_version,
323+
live_session_id=live_session_id,
313324
)
314325
self._output_transcription_text = ''
315326
if message.server_content.turn_complete:
@@ -321,13 +332,15 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
321332
yield LlmResponse(
322333
content=types.Content(role='model', parts=tool_call_parts),
323334
model_version=self._model_version,
335+
live_session_id=live_session_id,
324336
)
325337
tool_call_parts = []
326338
yield LlmResponse(
327339
turn_complete=True,
328340
interrupted=message.server_content.interrupted,
329341
grounding_metadata=message.server_content.grounding_metadata,
330342
model_version=self._model_version,
343+
live_session_id=live_session_id,
331344
)
332345
break
333346
# in case of empty content or parts, we still surface it
@@ -342,6 +355,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
342355
yield LlmResponse(
343356
interrupted=message.server_content.interrupted,
344357
model_version=self._model_version,
358+
live_session_id=live_session_id,
345359
)
346360
if message.tool_call:
347361
logger.debug('Received tool call: %s', message.tool_call)
@@ -358,20 +372,23 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
358372
LlmResponse(
359373
live_session_resumption_update=message.session_resumption_update,
360374
model_version=self._model_version,
375+
live_session_id=live_session_id,
361376
)
362377
)
363378
if message.go_away:
364379
logger.debug('Received GoAway message: %s', message.go_away)
365380
yield LlmResponse(
366381
go_away=message.go_away,
367382
model_version=self._model_version,
383+
live_session_id=live_session_id,
368384
)
369385

370386
if tool_call_parts:
371387
logger.debug('Exited loop with pending tool_call_parts')
372388
yield LlmResponse(
373389
content=types.Content(role='model', parts=tool_call_parts),
374390
model_version=self._model_version,
391+
live_session_id=self._gemini_session.session_id,
375392
)
376393

377394
async def close(self):

src/google/adk/models/llm_response.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ class LlmResponse(BaseModel):
110110
] = None
111111
"""The session resumption update of the LlmResponse"""
112112

113+
live_session_id: Optional[str] = None
114+
"""The session ID of the Live session."""
115+
113116
go_away: Optional[types.LiveServerGoAway] = None
114117
"""The GoAway signal from the Live model."""
115118

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
@pytest.fixture
2626
def mock_gemini_session():
2727
"""Mock Gemini session for testing."""
28-
return mock.AsyncMock()
28+
mock_session = mock.AsyncMock()
29+
mock_session.session_id = 'test-session-id'
30+
return mock_session
2931

3032

3133
@pytest.fixture
@@ -247,6 +249,41 @@ async def mock_receive_generator():
247249
assert content_response.content == mock_content
248250

249251

252+
async def test_receive_populates_live_session_id(
253+
gemini_connection, mock_gemini_session
254+
):
255+
"""Test that receive populates live_session_id in LlmResponse."""
256+
mock_message = mock.AsyncMock()
257+
mock_message.usage_metadata = None
258+
mock_message.server_content = None
259+
mock_message.tool_call = None
260+
mock_message.session_resumption_update = None
261+
mock_message.go_away = None
262+
263+
mock_server_content = mock.Mock()
264+
mock_server_content.model_turn = types.Content(
265+
role='model', parts=[types.Part.from_text(text='text')]
266+
)
267+
mock_server_content.interrupted = False
268+
mock_server_content.input_transcription = None
269+
mock_server_content.output_transcription = None
270+
mock_server_content.turn_complete = False
271+
mock_server_content.grounding_metadata = None
272+
273+
mock_message.server_content = mock_server_content
274+
275+
async def mock_receive_generator():
276+
yield mock_message
277+
278+
mock_gemini_session.receive = mock.Mock(return_value=mock_receive_generator())
279+
280+
responses = [resp async for resp in gemini_connection.receive()]
281+
282+
assert responses
283+
for resp in responses:
284+
assert resp.live_session_id == 'test-session-id'
285+
286+
250287
@pytest.mark.asyncio
251288
async def test_receive_transcript_finished_on_interrupt(
252289
gemini_api_connection,

0 commit comments

Comments (0)