Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion litellm/responses/litellm_completion_transformation/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,16 @@ async def async_response_api_handler(
)

# CARTO PATCH: Store session immediately in Redis to avoid batch processing delay
# Store BOTH input messages AND the assistant response to preserve tool_calls with thought_signature
if responses_api_response.id:
session_id = kwargs.get("litellm_trace_id") or str(uuid.uuid4())
current_messages = litellm_completion_request.get("messages", [])
current_messages = list(litellm_completion_request.get("messages", []))

# Extract assistant message from the completion response
assistant_message = _extract_assistant_message_from_response(litellm_completion_response)
if assistant_message:
current_messages.append(assistant_message)

await LiteLLMCompletionResponsesConfig._patch_store_session_in_redis(
response_id=responses_api_response.id,
session_id=session_id,
Expand All @@ -153,3 +160,78 @@ async def async_response_api_handler(
litellm_metadata=kwargs.get("litellm_metadata", {}),
litellm_completion_request=litellm_completion_request,
)


def _extract_assistant_message_from_response(response: ModelResponse) -> Optional[Dict]:
"""
Extract the assistant message from a ModelResponse, preserving tool_calls with provider_specific_fields.

This is critical for Gemini thinking models where thought_signature must be preserved
for multi-turn conversations with tool calling.
"""
if not response.choices:
return None

choice = response.choices[0]
if not hasattr(choice, 'message') or not choice.message:
return None

msg = choice.message

assistant_message: Dict[str, Any] = {
"role": "assistant",
}

# Add content if present
content = getattr(msg, 'content', None)
if content is not None:
assistant_message["content"] = content

# Add tool_calls with provider_specific_fields (contains thought_signature for Gemini)
tool_calls = getattr(msg, 'tool_calls', None)
if tool_calls:
serialized_tool_calls = []
for tc in tool_calls:
tool_call_dict: Dict[str, Any] = {
"id": getattr(tc, 'id', None) or tc.get('id') if isinstance(tc, dict) else getattr(tc, 'id', None),
"type": getattr(tc, 'type', 'function') or tc.get('type', 'function') if isinstance(tc, dict) else getattr(tc, 'type', 'function'),
}

# Extract function details
fn = tc.get('function') if isinstance(tc, dict) else getattr(tc, 'function', None)
if fn:
fn_dict: Dict[str, Any] = {
"name": fn.get('name') if isinstance(fn, dict) else getattr(fn, 'name', ''),
"arguments": fn.get('arguments', '') if isinstance(fn, dict) else getattr(fn, 'arguments', ''),
}
# Preserve function-level provider_specific_fields (contains thought_signature)
fn_provider_fields = fn.get('provider_specific_fields') if isinstance(fn, dict) else getattr(fn, 'provider_specific_fields', None)
if fn_provider_fields:
fn_dict["provider_specific_fields"] = fn_provider_fields
tool_call_dict["function"] = fn_dict

# Also check for tool_call-level provider_specific_fields
tc_provider_fields = tc.get('provider_specific_fields') if isinstance(tc, dict) else getattr(tc, 'provider_specific_fields', None)
if tc_provider_fields:
tool_call_dict["provider_specific_fields"] = tc_provider_fields

serialized_tool_calls.append(tool_call_dict)

assistant_message["tool_calls"] = serialized_tool_calls

# Preserve message-level provider_specific_fields
msg_provider_fields = getattr(msg, 'provider_specific_fields', None)
if msg_provider_fields:
assistant_message["provider_specific_fields"] = msg_provider_fields

# Preserve thinking_blocks if present (Anthropic/Claude)
thinking_blocks = getattr(msg, 'thinking_blocks', None)
if thinking_blocks:
assistant_message["thinking_blocks"] = thinking_blocks

# Preserve reasoning_content if present
reasoning_content = getattr(msg, 'reasoning_content', None)
if reasoning_content:
assistant_message["reasoning_content"] = reasoning_content

return assistant_message
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import time
import uuid
from typing import List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Union, cast

import litellm
from litellm.main import stream_chunk_builder
Expand Down Expand Up @@ -1062,34 +1062,118 @@ def _emit_response_completed_event(
async def _store_session_in_redis(self, response_completed_event: ResponseCompletedEvent):
    """
    PATCH: Store session in Redis for streaming responses.
    This fixes the issue where Redis sessions weren't created for streaming requests.

    For Gemini thinking models with tool calling, we must preserve:
    - tool_calls with provider_specific_fields (contains thought_signature)
    - message-level provider_specific_fields

    Args:
        response_completed_event: The terminal streaming event whose
            `.response` carries the response id to key the session on.
    """
    try:
        response = response_completed_event.response
        if response and response.id:
            # Get the session ID from metadata or from the completion request
            session_id = (self.litellm_completion_request.get("litellm_trace_id") or
                          self.litellm_metadata.get("litellm_trace_id") or
                          str(uuid.uuid4()))

            # Get the full messages from the completion request (includes history).
            # Copy so we never mutate the caller's request dict.
            messages = list(self.litellm_completion_request.get("messages", []))

            # Extract assistant message from the reconstructed ModelResponse.
            # This preserves tool_calls with provider_specific_fields (thought_signature)
            assistant_message = self._extract_assistant_message_for_redis()
            if assistant_message:
                messages.append(assistant_message)

            # Store session in Redis
            await LiteLLMCompletionResponsesConfig._patch_store_session_in_redis(
                response_id=response.id,
                session_id=session_id,
                messages=messages
            )
    except Exception:
        # Silently fail - Redis storage is a patch for timing issues
        # and shouldn't break the streaming response
        pass

def _extract_assistant_message_for_redis(self) -> Optional[Dict[str, Any]]:
"""
Extract the assistant message from the reconstructed ModelResponse,
preserving tool_calls with provider_specific_fields.

This is critical for Gemini thinking models where thought_signature must be preserved
for multi-turn conversations with tool calling.
"""
if not self.litellm_model_response:
return None

if not hasattr(self.litellm_model_response, 'choices') or not self.litellm_model_response.choices:
return None

choice = self.litellm_model_response.choices[0]
if not hasattr(choice, 'message') or not choice.message:
return None

msg = choice.message

assistant_message: Dict[str, Any] = {
"role": "assistant",
}

# Add content if present
content = getattr(msg, 'content', None)
if content is not None:
assistant_message["content"] = content

# Add tool_calls with provider_specific_fields (contains thought_signature for Gemini)
tool_calls = getattr(msg, 'tool_calls', None)
if tool_calls:
serialized_tool_calls = []
for tc in tool_calls:
tool_call_dict: Dict[str, Any] = {
"id": getattr(tc, 'id', None) or tc.get('id') if isinstance(tc, dict) else getattr(tc, 'id', None),
"type": getattr(tc, 'type', 'function') or tc.get('type', 'function') if isinstance(tc, dict) else getattr(tc, 'type', 'function'),
}

# Extract function details
fn = tc.get('function') if isinstance(tc, dict) else getattr(tc, 'function', None)
if fn:
fn_dict: Dict[str, Any] = {
"name": fn.get('name') if isinstance(fn, dict) else getattr(fn, 'name', ''),
"arguments": fn.get('arguments', '') if isinstance(fn, dict) else getattr(fn, 'arguments', ''),
}
# Preserve function-level provider_specific_fields (contains thought_signature)
fn_provider_fields = fn.get('provider_specific_fields') if isinstance(fn, dict) else getattr(fn, 'provider_specific_fields', None)
if fn_provider_fields:
fn_dict["provider_specific_fields"] = fn_provider_fields
tool_call_dict["function"] = fn_dict

# Also check for tool_call-level provider_specific_fields
tc_provider_fields = tc.get('provider_specific_fields') if isinstance(tc, dict) else getattr(tc, 'provider_specific_fields', None)
if tc_provider_fields:
tool_call_dict["provider_specific_fields"] = tc_provider_fields

serialized_tool_calls.append(tool_call_dict)

assistant_message["tool_calls"] = serialized_tool_calls

# Preserve message-level provider_specific_fields
msg_provider_fields = getattr(msg, 'provider_specific_fields', None)
if msg_provider_fields:
assistant_message["provider_specific_fields"] = msg_provider_fields

# Preserve thinking_blocks if present (Anthropic/Claude)
thinking_blocks = getattr(msg, 'thinking_blocks', None)
if thinking_blocks:
assistant_message["thinking_blocks"] = thinking_blocks

# Preserve reasoning_content if present
reasoning_content = getattr(msg, 'reasoning_content', None)
if reasoning_content:
assistant_message["reasoning_content"] = reasoning_content

# If we have no meaningful content, return None
if not content and not tool_calls:
return None

return assistant_message
Loading