Skip to content

Commit 243720f

Browse files
committed
test(workflow): Ensure unique paths and state persistence in loop tests
Consolidated regression tests for wait_for_output nodes in loops. Migrated wait_for_output logic from HITL tests to core workflow tests. Verified path uniqueness (run_id) and state rehydration. Extracted a shared helper for function call interruption simulation. Change-Id: Ie9c19a93d0ed71941301865352641c08e5657273
1 parent 2368fb8 commit 243720f

2 files changed

Lines changed: 155 additions & 140 deletions

File tree

tests/unittests/workflow/test_workflow_class.py

Lines changed: 154 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,21 @@
4242
# Shared helper nodes (used by multiple tests)
4343
# ---------------------------------------------------------------------------
4444

45+
def _make_function_call_interrupt(fc_id: str, name: str = 'approve') -> Event:
46+
"""Helper to create a raw function call interruption event."""
47+
return Event(
48+
content=types.Content(
49+
parts=[
50+
types.Part(
51+
function_call=types.FunctionCall(
52+
name=name, args={}, id=fc_id
53+
)
54+
)
55+
]
56+
),
57+
long_running_tool_ids={fc_id},
58+
)
59+
4560

4661
class _OutputNode(BaseNode):
4762
"""Yields a fixed output value."""
@@ -1698,21 +1713,17 @@ def outer_node(node_input: str) -> str:
16981713

16991714

17001715
@pytest.mark.asyncio
1701-
async def test_wait_for_output_node_preserved_across_resume():
1702-
"""wait_for_output node is marked WAITING on resume, not re-triggered.
1703-
1704-
Scenario:
1705-
START → A (completes), B (interrupts)
1706-
A → Gate(wait_for_output, needs 2 triggers)
1707-
B → Gate
1708-
Gate → Downstream
1709-
1710-
Run 1: A completes → triggers Gate (1/2, no output). B interrupts.
1711-
Run 2: Resume B → B completes → triggers Gate (2/2, outputs).
1712-
Gate opens → Downstream runs.
1713-
1714-
Without _add_wait_for_output_nodes, Gate would be treated as fresh
1715-
on resume and only receive 1 trigger (from B), never opening.
1716+
async def test_wait_for_output_node_preserves_state_across_resume():
1717+
"""Wait-for-output node preserves received triggers across workflow resume.
1718+
1719+
Setup:
1720+
START -> NodeA (completes) -> Gate (wait_for_output, needs 2 triggers).
1721+
START -> NodeB (interrupts) -> Gate.
1722+
Act:
1723+
- Turn 1: Start workflow. NodeA completes, NodeB interrupts. Gate waits (1/2).
1724+
- Turn 2: Resume with NodeB response. NodeB completes, triggers Gate (2/2).
1725+
Assert:
1726+
- Gate opens in Turn 2 and produces output.
17161727
"""
17171728

17181729
class _InterruptOnce(BaseNode):
@@ -1728,18 +1739,7 @@ async def _run_impl(
17281739
return
17291740
fc_id = f'fc-{uuid.uuid4().hex[:8]}'
17301741
ctx.state['_fc_id'] = fc_id
1731-
yield Event(
1732-
content=types.Content(
1733-
parts=[
1734-
types.Part(
1735-
function_call=types.FunctionCall(
1736-
name='approve', args={}, id=fc_id
1737-
)
1738-
)
1739-
]
1740-
),
1741-
long_running_tool_ids={fc_id},
1742-
)
1742+
yield _make_function_call_interrupt(fc_id)
17431743

17441744
a = _OutputNode(name='NodeA', value='A')
17451745
b = _InterruptOnce(name='NodeB')
@@ -1802,6 +1802,125 @@ async def _run_impl(
18021802
assert 'done' in outputs
18031803

18041804

1805+
@pytest.mark.asyncio
1806+
async def test_wait_for_output_node_in_loop_generates_unique_paths():
1807+
"""Wait-for-output node in loop generates unique paths across iterations.
1808+
1809+
Setup:
1810+
Workflow with a loop: START -> Process -> [NodeA, NodeB] -> Join -> Handle.
1811+
Handle loops back to Process on first iteration, exits on second.
1812+
NodeB interrupts on first run of each iteration.
1813+
Act:
1814+
- Turn 1: Start workflow. NodeA completes, NodeB interrupts. Join waits.
1815+
- Turn 2: Resume with NodeB response. Handle loops back. Process runs again.
1816+
NodeA completes, NodeB interrupts again.
1817+
- Turn 3: Resume with NodeB response again. Join opens.
1818+
Assert:
1819+
- JoinNode runs with path ending in `Join@2` in the second iteration.
1820+
"""
1821+
1822+
1823+
class _InterruptOnce(BaseNode):
1824+
rerun_on_resume: bool = True
1825+
1826+
async def _run_impl(
1827+
self, *, ctx: Context, node_input: Any
1828+
) -> AsyncGenerator[Any, None]:
1829+
fc_id = ctx.state.get('_fc_id')
1830+
if fc_id and ctx.resume_inputs and fc_id in ctx.resume_inputs:
1831+
ctx.state['_fc_id'] = None
1832+
yield 'B_done'
1833+
return
1834+
fc_id = 'fc-interrupt'
1835+
ctx.state['_fc_id'] = fc_id
1836+
yield _make_function_call_interrupt(fc_id)
1837+
1838+
class _HandleNode(BaseNode):
1839+
1840+
async def _run_impl(
1841+
self, *, ctx: Context, node_input: Any
1842+
) -> AsyncGenerator[Any, None]:
1843+
# Loop back if this is the first iteration
1844+
count = ctx.state.get('loop_count', 0) + 1
1845+
ctx.state['loop_count'] = count
1846+
if count == 1:
1847+
yield Event(route='loop')
1848+
else:
1849+
yield Event(route='exit', output='finished')
1850+
1851+
process = _PassthroughNode(name='Process')
1852+
a = _OutputNode(name='NodeA', value='A')
1853+
b = _InterruptOnce(name='NodeB')
1854+
join = JoinNode(name='Join')
1855+
handle = _HandleNode(name='Handle')
1856+
exit_node = _PassthroughNode(name='Exit')
1857+
1858+
wf = Workflow(
1859+
name='loop_wf',
1860+
edges=[
1861+
(START, process),
1862+
(process, a),
1863+
(process, b),
1864+
(a, join),
1865+
(b, join),
1866+
(join, handle),
1867+
(handle, {'loop': process, 'exit': exit_node}),
1868+
],
1869+
)
1870+
1871+
ss = InMemorySessionService()
1872+
runner = Runner(app_name='test', node=wf, session_service=ss)
1873+
session = await ss.create_session(app_name='test', user_id='u')
1874+
1875+
# Turn 1: process runs, A completes, B interrupts, Join waits
1876+
msg1 = types.Content(parts=[types.Part(text='go')], role='user')
1877+
events1: list[Event] = []
1878+
async for event in runner.run_async(
1879+
user_id='u', session_id=session.id, new_message=msg1
1880+
):
1881+
events1.append(event)
1882+
1883+
# Verify paths in Turn 1
1884+
join_events1 = [
1885+
e for e in events1 if e.node_info and 'Join' in e.node_info.path
1886+
]
1887+
# JoinNode does not yield events until all inputs are collected.
1888+
1889+
# Turn 2: Provide response for NodeB (InterruptNode)
1890+
msg2 = types.Content(
1891+
parts=[
1892+
types.Part(
1893+
function_response=types.FunctionResponse(
1894+
name='approve', id='fc-interrupt', response={'ok': True}
1895+
)
1896+
)
1897+
],
1898+
role='user',
1899+
)
1900+
1901+
events2: list[Event] = []
1902+
async for event in runner.run_async(
1903+
user_id='u', session_id=session.id, new_message=msg2
1904+
):
1905+
events2.append(event)
1906+
1907+
# Workflow loops back. NodeB interrupts again in the second iteration.
1908+
1909+
# Turn 3: Provide response for NodeB again!
1910+
events3: list[Event] = []
1911+
async for event in runner.run_async(
1912+
user_id='u', session_id=session.id, new_message=msg2
1913+
):
1914+
events3.append(event)
1915+
1916+
# JoinNode opens in the second iteration and produces output.
1917+
1918+
join_events3 = [
1919+
e for e in events3 if e.node_info and 'Join@2' in e.node_info.path
1920+
]
1921+
assert len(join_events3) > 0, "JoinNode should run again in loop with @2"
1922+
1923+
18051924
# --- run_id reuse on resume ---
18061925

18071926

@@ -1822,7 +1941,7 @@ async def _run_impl(
18221941
return
18231942
fc_id = str(uuid.uuid4())
18241943
ctx.state['_fc_id'] = fc_id
1825-
yield RequestInput(interrupt_id=fc_id)
1944+
yield _make_function_call_interrupt(fc_id)
18261945

18271946
wf = Workflow(
18281947
name='wf',
@@ -1852,7 +1971,13 @@ async def _run_impl(
18521971

18531972
# Run 2: resume
18541973
msg2 = types.Content(
1855-
parts=[create_request_input_response(fc_id, {'ok': True})],
1974+
parts=[
1975+
types.Part(
1976+
function_response=types.FunctionResponse(
1977+
name='approve', id=fc_id, response={'ok': True}
1978+
)
1979+
)
1980+
],
18561981
role='user',
18571982
)
18581983
events2: list[Event] = []
@@ -1871,6 +1996,3 @@ async def _run_impl(
18711996
]
18721997
assert len(resumed_events) == 1
18731998
assert resumed_events[0].node_info.run_id == original_run_id
1874-
1875-
1876-

tests/unittests/workflow/test_workflow_hitl.py

Lines changed: 1 addition & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from google.adk.tools.long_running_tool import LongRunningFunctionTool
3434
from google.adk.workflow import BaseNode
3535
from google.adk.workflow import Edge
36+
from google.adk.workflow import node
3637
from google.adk.workflow import START
3738
from google.adk.workflow._node_status import NodeStatus
3839
from google.adk.workflow._workflow_class import Workflow
@@ -1297,8 +1298,6 @@ async def test_request_input_rerun_with_same_interrupt_id(
12971298
interrupt appear "already resolved", causing the workflow to
12981299
restart from scratch instead of resuming.
12991300
"""
1300-
from google.adk.workflow import node
1301-
13021301
@node(rerun_on_resume=True)
13031302
def review(ctx: Context):
13041303
resume = ctx.resume_inputs.get('review')
@@ -1468,8 +1467,6 @@ async def test_second_auth_node_skips_auth_when_credential_exists(
14681467
from google.adk.auth.auth_credential import AuthCredential
14691468
from google.adk.auth.auth_credential import AuthCredentialTypes
14701469
from google.adk.auth.auth_tool import AuthConfig
1471-
from google.adk.workflow import node
1472-
14731470
auth_config = AuthConfig(
14741471
auth_scheme=APIKey(**{'in': APIKeyIn.header, 'name': 'X-Api-Key'}),
14751472
raw_auth_credential=AuthCredential(
@@ -1541,107 +1538,3 @@ def second_task():
15411538
# Both nodes ran — node_b did NOT pause for a second auth request.
15421539
assert call_log == ['first', 'second']
15431540
assert sink.received_inputs == [{'status': 'done'}]
1544-
1545-
1546-
@pytest.mark.asyncio
1547-
async def test_workflow_loop_generates_unique_paths_across_resume(
1548-
request: pytest.FixtureRequest
1549-
):
1550-
"""Workflow loop generates unique sequential paths across resumes.
1551-
1552-
Setup: workflow simulating request_input sample with a loop and a RequestInput node.
1553-
Act:
1554-
- Turn 1: trigger RequestInput and interrupt.
1555-
- Turn 2: provide response triggering a loop back, and trigger RequestInput again.
1556-
Assert:
1557-
- Turn 1: node path has @1.
1558-
- Turn 2: node path has @2.
1559-
"""
1560-
from google.adk.workflow import node
1561-
from google.adk.apps import App
1562-
from google.adk.events.event import Event
1563-
from google.adk.events.request_input import RequestInput
1564-
1565-
from tests.unittests import testing_utils
1566-
from tests.unittests.workflow import workflow_testing_utils
1567-
1568-
# Given a workflow simulating the request_input sample
1569-
@node
1570-
def process_input(node_input: Any):
1571-
yield Event(state={"complaint": node_input, "feedback": ""})
1572-
1573-
@node
1574-
def draft_email(ctx: Context):
1575-
complaint = ctx.state.get('complaint')
1576-
feedback = ctx.state.get('feedback')
1577-
yield Event(output=f"Draft based on {complaint} and feedback {feedback}")
1578-
1579-
@node(rerun_on_resume=True)
1580-
def request_human_review(node_input: Any, ctx: Context):
1581-
resume = ctx.resume_inputs.get('human_review')
1582-
if not resume:
1583-
yield RequestInput(
1584-
interrupt_id='human_review',
1585-
message=f"Please review: {node_input}",
1586-
)
1587-
return
1588-
yield Event(output=resume)
1589-
1590-
request_human_review.wait_for_output = True
1591-
1592-
@node
1593-
def handle_human_review(node_input: Any):
1594-
result = node_input.get('result') if isinstance(node_input, dict) else node_input
1595-
if result == "approve":
1596-
yield Event(route="approved")
1597-
else:
1598-
yield Event(state={"feedback": result}, route="revise")
1599-
1600-
@node
1601-
def end_node(node_input: Any):
1602-
yield Event(output="done")
1603-
1604-
wf = Workflow(
1605-
name="request_input",
1606-
edges=[
1607-
(
1608-
START,
1609-
process_input,
1610-
draft_email,
1611-
request_human_review,
1612-
handle_human_review,
1613-
),
1614-
(handle_human_review, {"revise": draft_email, "approved": end_node}),
1615-
],
1616-
)
1617-
1618-
app = App(
1619-
name=request.function.__name__,
1620-
root_agent=wf,
1621-
)
1622-
runner = testing_utils.InMemoryRunner(app=app)
1623-
1624-
# When Turn 1 executes (starts and interrupts)
1625-
events1 = await runner.run_async(
1626-
testing_utils.get_user_content("my complaint")
1627-
)
1628-
1629-
# Then verify it interrupted at request_human_review@1
1630-
req1 = workflow_testing_utils.get_request_input_events(events1)
1631-
assert len(req1) == 1
1632-
assert 'request_human_review@1' in req1[0].node_info.path
1633-
1634-
inv_id = events1[0].invocation_id
1635-
1636-
# When Turn 2 executes (provides response and loops back)
1637-
events2 = await runner.run_async(
1638-
new_message=testing_utils.UserContent(
1639-
create_request_input_response('human_review', {'result': 'make it shorter'})
1640-
),
1641-
invocation_id=inv_id,
1642-
)
1643-
1644-
# Then verify it triggered request_human_review again with run_id @2
1645-
req2 = workflow_testing_utils.get_request_input_events(events2)
1646-
assert len(req2) == 1
1647-
assert 'request_human_review@2' in req2[0].node_info.path

0 commit comments

Comments
 (0)