Skip to content

Commit d6594a1

Browse files
google-genai-bot authored and copybara-github committed
feat: Add support for refusal messages in ApigeeLlm
If content and refusal chunks are interleaved, this will drop the remaining content chunks after the first refusal chunk appears.

PiperOrigin-RevId: 901457248
1 parent 782796f commit d6594a1

3 files changed

Lines changed: 200 additions & 17 deletions

File tree

src/google/adk/models/apigee_llm.py

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
'object',
6161
)
6262

63+
_REFUSAL_PREFIX = '[[REFUSAL]]: '
64+
6365

6466
class ApigeeLlm(Gemini):
6567
"""A BaseLlm implementation for calling Apigee proxy.
@@ -658,11 +660,14 @@ def _content_to_messages(
658660

659661
tool_calls = []
660662
content_parts = []
663+
refusals: list[str] = []
661664

662665
function_responses = []
663666

664667
for part in content.parts or []:
665-
self._process_content_part(content, part, tool_calls, content_parts)
668+
self._process_content_part(
669+
content, part, tool_calls, content_parts, refusals
670+
)
666671
if part.function_response:
667672
function_responses.append({
668673
'role': 'tool',
@@ -673,6 +678,8 @@ def _content_to_messages(
673678
return function_responses
674679

675680
message = {'role': role}
681+
if refusals:
682+
message['refusal'] = '\n'.join(refusals)
676683
if tool_calls:
677684
message['tool_calls'] = tool_calls
678685
if not content_parts:
@@ -691,6 +698,7 @@ def _process_content_part(
691698
part: types.Part,
692699
tool_calls: list[dict[str, Any]],
693700
content_parts: list[dict[str, Any]],
701+
refusals: list[str],
694702
) -> None:
695703
"""Processes a single Part and updates tool_calls or content_parts."""
696704
if content.role != 'user' and (
@@ -731,7 +739,14 @@ def _process_content_part(
731739
# Handled in the loop to return immediately
732740
pass
733741
elif part.text:
734-
content_parts.append({'type': 'text', 'text': part.text})
742+
if part.text.startswith(_REFUSAL_PREFIX):
743+
refusals.append(part.text.removeprefix(_REFUSAL_PREFIX))
744+
else:
745+
before, sep, after = part.text.partition('\n' + _REFUSAL_PREFIX)
746+
if sep:
747+
refusals.append(after)
748+
if before:
749+
content_parts.append({'type': 'text', 'text': before})
735750
elif part.inline_data:
736751
mime_type = part.inline_data.mime_type
737752
data = base64.b64encode(part.inline_data.data).decode('utf-8')
@@ -843,6 +858,7 @@ def __init__(self):
843858
self.usage = {}
844859
self.logprobs = {}
845860
self.custom_metadata = {}
861+
self._refusal_started = False
846862

847863
def process_response(self, response: dict[str, Any]) -> LlmResponse:
848864
"""Processes a complete non-streaming response."""
@@ -989,19 +1005,49 @@ def _accumulate_logprobs(self, logprobs_chunk: dict[str, Any]) -> None:
9891005
self.logprobs['refusal'] = []
9901006
self.logprobs['refusal'].extend(logprobs_chunk['refusal'])
9911007

992-
def _append_content(self, content: str, refusal: str) -> str:
993-
if content and refusal:
994-
content += '\n'
995-
content += refusal
996-
elif refusal:
997-
content = refusal
1008+
def _accumulate_content(self, choice: dict[str, Any]) -> str:
1009+
"""Processes a message or delta chunk to accumulate content and refusals.
1010+
1011+
This method extracts 'content' and 'refusal' from the chunk, updates the
1012+
accumulated state (self.content_parts), and returns the text content for
1013+
this chunk (handling prefixes and newlines if it's a refusal).
1014+
1015+
Args:
1016+
choice: A dictionary representing a message choice or a streaming delta.
1017+
1018+
Returns:
1019+
The text content to be appended or yielded for this chunk.
1020+
"""
1021+
content = choice.get('content', '')
1022+
refusal = choice.get('refusal', '')
1023+
1024+
if content and self._refusal_started:
1025+
logging.warning(
1026+
'Received content after refusal has started. Dropping content.'
1027+
)
1028+
content = ''
1029+
1030+
chunk_text = ''
9981031
if content:
999-
self.content_parts += content
1000-
return content
1032+
chunk_text += content
1033+
1034+
if refusal and not self._refusal_started:
1035+
self._refusal_started = True
1036+
if self.content_parts or chunk_text:
1037+
chunk_text += '\n'
1038+
chunk_text += _REFUSAL_PREFIX
1039+
1040+
if refusal:
1041+
chunk_text += refusal
1042+
1043+
if chunk_text:
1044+
self.content_parts += chunk_text
1045+
1046+
return chunk_text
10011047

10021048
def _add_chat_completion_chunk_delta(
10031049
self, delta: dict[str, Any]
1004-
) -> (list[types.Part], str):
1050+
) -> tuple[list[types.Part], str]:
10051051
"""Adds a chunk delta from a streaming chat completions response.
10061052
10071053
This method processes a single delta chunk from a streaming chat completions
@@ -1021,9 +1067,7 @@ def _add_chat_completion_chunk_delta(
10211067
for tool_call in delta.get('tool_calls', []):
10221068
chunk_part = self._upsert_tool_call(tool_call)
10231069
parts.append(chunk_part)
1024-
content = delta.get('content')
1025-
refusal = delta.get('refusal')
1026-
merged_content = self._append_content(content, refusal)
1070+
merged_content = self._accumulate_content(delta)
10271071
if merged_content:
10281072
parts.append(types.Part.from_text(text=merged_content))
10291073

@@ -1057,9 +1101,7 @@ def _add_chat_completion_message(
10571101
'type': 'function',
10581102
'function': function_call,
10591103
})
1060-
content = message.get('content')
1061-
refusal = message.get('refusal')
1062-
self._append_content(content, refusal)
1104+
self._accumulate_content(message)
10631105

10641106
self._get_or_create_role(message.get('role', 'model'))
10651107
return self._get_content_parts(), self.role

tests/unittests/models/test_apigee_llm.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,3 +649,86 @@ def test_parse_response_usage_metadata():
649649
assert llm_response.usage_metadata.candidates_token_count == 5
650650
assert llm_response.usage_metadata.total_token_count == 15
651651
assert llm_response.usage_metadata.thoughts_token_count == 4
652+
653+
654+
def test_parse_response_with_refusal():
655+
"""Tests that CompletionsHTTPClient parses refusal correctly."""
656+
client = CompletionsHTTPClient(base_url='http://test')
657+
658+
response_dict = {
659+
'choices': [{
660+
'message': {
661+
'role': 'assistant',
662+
'refusal': 'I refuse to answer',
663+
},
664+
'finish_reason': 'stop',
665+
}],
666+
}
667+
llm_response = client._parse_response(response_dict)
668+
assert len(llm_response.content.parts) == 1
669+
assert llm_response.content.parts[0].text == '[[REFUSAL]]: I refuse to answer'
670+
671+
response_dict_mixed = {
672+
'choices': [{
673+
'message': {
674+
'role': 'assistant',
675+
'content': 'Here is some content',
676+
'refusal': 'But I refuse to answer the rest',
677+
},
678+
'finish_reason': 'stop',
679+
}],
680+
}
681+
llm_response_mixed = client._parse_response(response_dict_mixed)
682+
assert len(llm_response_mixed.content.parts) == 1
683+
assert (
684+
llm_response_mixed.content.parts[0].text
685+
== 'Here is some content\n[[REFUSAL]]: But I refuse to answer the rest'
686+
)
687+
688+
689+
@pytest.mark.parametrize(
690+
('parts', 'expected_message'),
691+
[
692+
(
693+
[
694+
types.Part.from_text(text='[[REFUSAL]]: I refuse to answer'),
695+
types.Part.from_text(text='normal content'),
696+
],
697+
{
698+
'role': 'assistant',
699+
'refusal': 'I refuse to answer',
700+
'content': 'normal content',
701+
},
702+
),
703+
(
704+
[
705+
types.Part.from_text(
706+
text=(
707+
'Here is some content\n[[REFUSAL]]: But I refuse to'
708+
' answer the rest'
709+
)
710+
),
711+
],
712+
{
713+
'role': 'assistant',
714+
'refusal': 'But I refuse to answer the rest',
715+
'content': 'Here is some content',
716+
},
717+
),
718+
],
719+
)
720+
def test_construct_payload_with_refusal(parts, expected_message):
721+
"""Tests that CompletionsHTTPClient constructs payload with refusal correctly."""
722+
client = CompletionsHTTPClient(base_url='http://test')
723+
req = LlmRequest(
724+
model='apigee/openai/gpt-4o',
725+
contents=[
726+
types.Content(
727+
role='model',
728+
parts=parts,
729+
)
730+
],
731+
)
732+
payload = client._construct_payload(req, stream=False)
733+
messages = payload['messages']
734+
assert messages == [expected_message]

tests/unittests/models/test_completions_http_client.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from unittest import mock
1717
from unittest.mock import AsyncMock
1818

19+
from google.adk.models.apigee_llm import ChatCompletionsResponseHandler
1920
from google.adk.models.apigee_llm import CompletionsHTTPClient
2021
from google.adk.models.llm_request import LlmRequest
2122
from google.genai import types
@@ -771,3 +772,60 @@ async def mock_aiter_lines():
771772
]
772773
assert len(responses) == expected_response_count
773774
assert responses[0].content.parts[0].text == 'Hello'
775+
776+
777+
def test_process_chunk_with_refusal_streaming():
778+
handler = ChatCompletionsResponseHandler()
779+
780+
chunk1 = {
781+
'choices': [{
782+
'delta': {
783+
'role': 'assistant',
784+
'content': 'Hello',
785+
},
786+
'index': 0,
787+
}]
788+
}
789+
responses1 = list(handler.process_chunk(chunk1))
790+
assert len(responses1) == 1
791+
assert responses1[0].content.parts[0].text == 'Hello'
792+
793+
chunk2 = {
794+
'choices': [{
795+
'delta': {
796+
'refusal': 'I refuse',
797+
},
798+
'index': 0,
799+
}]
800+
}
801+
responses2 = list(handler.process_chunk(chunk2))
802+
assert len(responses2) == 1
803+
assert responses2[0].content.parts[0].text == '\n[[REFUSAL]]: I refuse'
804+
805+
chunk3 = {
806+
'choices': [{
807+
'delta': {
808+
'refusal': ' to answer',
809+
},
810+
'index': 0,
811+
}]
812+
}
813+
responses3 = list(handler.process_chunk(chunk3))
814+
assert len(responses3) == 1
815+
assert responses3[0].content.parts[0].text == ' to answer'
816+
817+
chunk4 = {
818+
'choices': [{
819+
'delta': {},
820+
'finish_reason': 'stop',
821+
'index': 0,
822+
}]
823+
}
824+
responses4 = list(handler.process_chunk(chunk4))
825+
assert len(responses4) == 2
826+
final_response = responses4[1]
827+
assert final_response.finish_reason == types.FinishReason.STOP
828+
assert (
829+
final_response.content.parts[0].text
830+
== 'Hello\n[[REFUSAL]]: I refuse to answer'
831+
)

0 commit comments

Comments (0)