Skip to content

Commit 6580f41

Browse files
committed
test(workflow): Add unit test for FunctionNode auth pause-and-resume
Change-Id: Idc2236f1b5efea54b6bf5965ed4ffe580f6a11e9
1 parent 70b533f commit 6580f41

1 file changed

Lines changed: 100 additions & 1 deletion

File tree

tests/unittests/workflow/test_workflow_llm_agent_interruptions.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2929
from google.adk.sessions.session import Session
3030
from google.adk.tools.long_running_tool import LongRunningFunctionTool
31+
from google.adk.tools.tool_context import ToolContext
3132
from google.adk.workflow import Edge
3233
from google.adk.workflow import START
3334
from google.adk.workflow import Workflow
@@ -219,8 +220,8 @@ async def test_workflow_pause_and_resume_tool_confirmation(
219220
- Run 1: Workflow pauses and yields confirmation request.
220221
- Run 2: Workflow resumes and completes with LLM response.
221222
"""
222-
from google.adk.tools.function_tool import FunctionTool
223223
from google.adk.flows.llm_flows.functions import REQUEST_CONFIRMATION_FUNCTION_CALL_NAME
224+
from google.adk.tools.function_tool import FunctionTool
224225

225226
# Given a tool that requires confirmation and a mock model
226227
def _simple_tool_func():
@@ -299,6 +300,104 @@ def _simple_tool_func():
299300
assert any('LLM response after confirmation' in t for t in content_texts)
300301

301302

303+
@pytest.mark.asyncio
304+
async def test_workflow_pause_and_resume_auth_node(
305+
request: pytest.FixtureRequest,
306+
):
307+
"""Workflow pauses on missing credentials and resumes when provided.
308+
309+
Setup: Workflow with a single node requiring auth.
310+
Act:
311+
- Run 1: Start workflow without credentials.
312+
- Run 2: Provide credentials via FunctionResponse.
313+
Assert:
314+
- Run 1: Workflow returns adk_request_credential request.
315+
- Run 2: Workflow completes and yields event with credential.
316+
"""
317+
from fastapi.openapi.models import APIKey
318+
from fastapi.openapi.models import APIKeyIn
319+
from google.adk.auth.auth_credential import AuthCredential
320+
from google.adk.auth.auth_credential import AuthCredentialTypes
321+
from google.adk.auth.auth_tool import AuthConfig
322+
from google.adk.workflow import FunctionNode
323+
324+
# Given a workflow with a node requiring auth
325+
auth_config = AuthConfig(
326+
auth_scheme=APIKey(**{'in': APIKeyIn.header, 'name': 'X-Api-Key'}),
327+
credential_key='my_key',
328+
)
329+
330+
def fetch_weather(ctx):
331+
cred = ctx.get_auth_response(auth_config)
332+
api_key = cred.api_key if cred else 'unknown'
333+
from google.adk import Event
334+
yield Event(message=f"authed with {api_key}")
335+
336+
node_a = FunctionNode(fetch_weather, auth_config=auth_config, rerun_on_resume=True)
337+
338+
wf = Workflow(
339+
name='test_workflow_auth_node',
340+
edges=[(START, node_a)],
341+
)
342+
343+
app = App(
344+
name=request.function.__name__,
345+
root_agent=wf,
346+
)
347+
runner = testing_utils.InMemoryRunner(app=app)
348+
349+
# When the workflow is started without credentials
350+
events1 = await runner.run_async(testing_utils.get_user_content('start'))
351+
352+
# Then it should pause and request credentials
353+
auth_fc_events = [
354+
e for e in events1
355+
if e.content
356+
and e.content.parts
357+
and e.content.parts[0].function_call
358+
and e.content.parts[0].function_call.name == "adk_request_credential"
359+
]
360+
assert len(auth_fc_events) == 1
361+
auth_fc_id = auth_fc_events[0].content.parts[0].function_call.id
362+
invocation_id = events1[0].invocation_id
363+
364+
# When the user provides the credentials
365+
auth_response = AuthConfig(
366+
auth_scheme=auth_config.auth_scheme,
367+
credential_key=auth_config.credential_key,
368+
exchanged_auth_credential=AuthCredential(
369+
auth_type=AuthCredentialTypes.API_KEY,
370+
api_key="secret_key",
371+
),
372+
)
373+
374+
user_credential_response = testing_utils.UserContent(
375+
types.Part(
376+
function_response=types.FunctionResponse(
377+
id=auth_fc_id,
378+
name="adk_request_credential",
379+
response=auth_response.model_dump(exclude_none=True, by_alias=True),
380+
),
381+
),
382+
)
383+
384+
# When the workflow is resumed
385+
events2 = await runner.run_async(
386+
new_message=user_credential_response,
387+
invocation_id=invocation_id,
388+
)
389+
390+
# Then the workflow should resume and complete
391+
content_texts = [
392+
p.text
393+
for e in events2
394+
if e.content and e.content.parts
395+
for p in e.content.parts
396+
if p.text
397+
]
398+
assert any('authed with secret_key' in t for t in content_texts)
399+
400+
302401
@pytest.mark.asyncio
303402
async def test_workflow_pause_and_resume_parent_interruption(
304403
request: pytest.FixtureRequest,

0 commit comments

Comments
 (0)