Skip to content

Commit 7127b61

Browse files
committed
fix(workflow): Restore run_counter during rehydration to ensure unique loop paths
When a workflow was rehydrated after an interruption, the `run_counter` in `NodeState` defaulted to 0 because it was not explicitly restored from events. In loop scenarios (e.g. human-in-the-loop (HITL) loops), this caused the engine to generate duplicate `run_id` paths (such as `@1`) when it looped back and triggered the node again, leading to path collisions and state-carryover bugs. This fix explicitly restores `run_counter` from the event history in `_restore_static_nodes_from_events`, so that subsequent runs of the same node in a loop get unique sequential paths (such as `@2`, `@3`). This ensures correctness in human-in-the-loop samples like `request_input`, where execution loops back based on user feedback.
Change-Id: I2a8e9d0797b47a441261b4d1aa3ab932424a8b54
1 parent 6580f41 commit 7127b61

2 files changed

Lines changed: 133 additions & 10 deletions

File tree

src/google/adk/workflow/_workflow_class.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,7 @@ def _restore_static_nodes_from_events(
536536
unresolved = child.interrupt_ids - child.resolved_ids
537537
existing_evt_run_id = child.run_id
538538

539+
run_counter = int(existing_evt_run_id) if existing_evt_run_id else 0
539540
if unresolved:
540541
node = self._get_static_node_by_name(child_name)
541542
if node.rerun_on_resume and child.resolved_ids:
@@ -547,6 +548,7 @@ def _restore_static_nodes_from_events(
547548
status=NodeStatus.PENDING,
548549
resume_inputs=child.resolved_responses,
549550
run_id=existing_evt_run_id,
551+
run_counter=run_counter,
550552
)
551553
else:
552554
# Child can't handle partial resume, or nothing resolved
@@ -555,12 +557,14 @@ def _restore_static_nodes_from_events(
555557
status=NodeStatus.WAITING,
556558
interrupts=list(unresolved),
557559
run_id=existing_evt_run_id,
560+
run_counter=run_counter,
558561
)
559562
elif child.output is not None:
560563
# Node's all interrupts are resolved and had output in previous run.
561564
nodes[child_name] = NodeState(
562565
status=NodeStatus.COMPLETED,
563566
run_id=existing_evt_run_id,
567+
run_counter=run_counter,
564568
)
565569
node_outputs[child_name] = child.output
566570
elif child.interrupt_ids:
@@ -570,17 +574,35 @@ def _restore_static_nodes_from_events(
570574
nodes[child_name] = NodeState(
571575
status=NodeStatus.COMPLETED,
572576
run_id=existing_evt_run_id,
577+
run_counter=run_counter,
573578
)
574579
node_outputs[child_name] = self._extract_resume_output(child, ctx)
575-
576580
# Mark that we need to trigger downstream for this node
577581
nodes_to_trigger.append((child_name, node_outputs[child_name]))
578582
else:
579583
nodes[child_name] = NodeState(
580584
status=NodeStatus.PENDING,
581585
resume_inputs=child.resolved_responses,
582586
run_id=existing_evt_run_id,
587+
run_counter=run_counter,
583588
)
589+
if child_name not in nodes:
590+
is_wait_for_output = False
591+
try:
592+
node = self._get_static_node_by_name(child_name)
593+
is_wait_for_output = node.wait_for_output
594+
except ValueError:
595+
pass
596+
597+
# For nodes with events but no output:
598+
# If wait_for_output is True, they are still WAITING for output.
599+
# Otherwise, they are considered COMPLETED (e.g., side-effect nodes).
600+
status = NodeStatus.WAITING if is_wait_for_output and child.output is None else NodeStatus.COMPLETED
601+
nodes[child_name] = NodeState(
602+
status=status,
603+
run_id=existing_evt_run_id,
604+
run_counter=run_counter,
605+
)
584606

585607
# wait_for_output nodes that were triggered but produced no output
586608
self._add_wait_for_output_nodes(nodes, children)
@@ -599,13 +621,6 @@ def _restore_static_nodes_from_events(
599621
for interrupt_id in state.interrupts
600622
}
601623

602-
# Restore run_counter from run_id so resumed nodes continue
603-
# sequential ids. When NodeState is persisted (resumability=on),
604-
# run_counter will already be correct from deserialization.
605-
for state in nodes.values():
606-
if state.run_id and state.run_id.isdigit():
607-
state.run_counter = int(state.run_id)
608-
609624
logger.info('node %s rehydrate end.', ctx.node_path)
610625

611626
def _extract_resume_output(self, child: _ChildScanState, ctx: Context) -> Any:
@@ -686,7 +701,7 @@ def _scan_child_events(self, ctx: Context) -> dict[str, _ChildScanState]:
686701

687702
# New run_id → reset child state (previous run stale).
688703
# ONLY update run_id from direct child events, not descendants!
689-
evt_run_id = event.node_info.run_id
704+
evt_run_id = event.node_info.path.rsplit('@', 1)[-1] if '@' in event.node_info.path else ''
690705
if (
691706
is_direct_child(event.node_info.path, workflow_path)
692707
and evt_run_id
@@ -751,6 +766,8 @@ def _process_resume(self, loop_state: _LoopState, ctx: Context) -> None:
751766
"""Seed triggers for PENDING nodes and collect interrupt IDs."""
752767
for node_name, node_state in loop_state.nodes.items():
753768
if node_state.status == NodeStatus.PENDING:
769+
if node_name in loop_state.trigger_buffer:
770+
continue
754771
loop_state.trigger_buffer.setdefault(node_name, []).append(
755772
Trigger(
756773
input=node_state.input,

tests/unittests/workflow/test_workflow_hitl.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@
3636
from google.adk.workflow import START
3737
from google.adk.workflow._node_status import NodeStatus
3838
from google.adk.workflow._workflow_class import Workflow
39-
from google.adk.workflow.utils._workflow_hitl_utils import REQUEST_CREDENTIAL_FUNCTION_CALL_NAME
4039
from google.adk.workflow.utils._workflow_hitl_utils import create_request_input_response
4140
from google.adk.workflow.utils._workflow_hitl_utils import get_request_input_interrupt_ids
41+
from google.adk.workflow.utils._workflow_hitl_utils import REQUEST_CREDENTIAL_FUNCTION_CALL_NAME
4242
from google.adk.workflow.utils._workflow_hitl_utils import REQUEST_INPUT_FUNCTION_CALL_NAME
4343
from google.adk.workflow.utils._workflow_hitl_utils import wrap_response
4444
from google.genai import types
@@ -1378,6 +1378,7 @@ def process():
13781378
events1 = await runner.run_async(testing_utils.get_user_content('go'))
13791379
req1 = workflow_testing_utils.get_request_input_events(events1)
13801380
assert len(req1) == 1
1381+
assert 'review@1' in req1[0].node_info.path
13811382
inv_id = events1[0].invocation_id
13821383

13831384
# Turn 2: revise → process reruns → review reruns → interrupt again
@@ -1389,6 +1390,7 @@ def process():
13891390
)
13901391
req2 = workflow_testing_utils.get_request_input_events(events2)
13911392
assert len(req2) == 1, 'Expected second interrupt after revise'
1393+
assert 'review@2' in req2[0].node_info.path
13921394
inv_id = events2[0].invocation_id
13931395

13941396
# Turn 3: approve → should complete, not loop
@@ -1580,3 +1582,107 @@ def second_task():
15801582
# Both nodes ran — node_b did NOT pause for a second auth request.
15811583
assert call_log == ['first', 'second']
15821584
assert sink.received_inputs == [{'status': 'done'}]
1585+
1586+
1587+
@pytest.mark.asyncio
1588+
async def test_workflow_loop_generates_unique_paths_across_resume(
1589+
request: pytest.FixtureRequest
1590+
):
1591+
"""Workflow loop generates unique sequential paths across resumes.
1592+
1593+
Setup: workflow simulating request_input sample with a loop and a RequestInput node.
1594+
Act:
1595+
- Turn 1: trigger RequestInput and interrupt.
1596+
- Turn 2: provide response triggering a loop back, and trigger RequestInput again.
1597+
Assert:
1598+
- Turn 1: node path has @1.
1599+
- Turn 2: node path has @2.
1600+
"""
1601+
from google.adk.workflow import node
1602+
from google.adk.apps import App
1603+
from google.adk.events.event import Event
1604+
from google.adk.events.request_input import RequestInput
1605+
1606+
from tests.unittests import testing_utils
1607+
from tests.unittests.workflow import workflow_testing_utils
1608+
1609+
# Given a workflow simulating the request_input sample
1610+
@node
1611+
def process_input(node_input: Any):
1612+
yield Event(state={"complaint": node_input, "feedback": ""})
1613+
1614+
@node
1615+
def draft_email(ctx: Context):
1616+
complaint = ctx.state.get('complaint')
1617+
feedback = ctx.state.get('feedback')
1618+
yield Event(output=f"Draft based on {complaint} and feedback {feedback}")
1619+
1620+
@node(rerun_on_resume=True)
1621+
def request_human_review(node_input: Any, ctx: Context):
1622+
resume = ctx.resume_inputs.get('human_review')
1623+
if not resume:
1624+
yield RequestInput(
1625+
interrupt_id='human_review',
1626+
message=f"Please review: {node_input}",
1627+
)
1628+
return
1629+
yield Event(output=resume)
1630+
1631+
request_human_review.wait_for_output = True
1632+
1633+
@node
1634+
def handle_human_review(node_input: Any):
1635+
result = node_input.get('result') if isinstance(node_input, dict) else node_input
1636+
if result == "approve":
1637+
yield Event(route="approved")
1638+
else:
1639+
yield Event(state={"feedback": result}, route="revise")
1640+
1641+
@node
1642+
def end_node(node_input: Any):
1643+
yield Event(output="done")
1644+
1645+
wf = Workflow(
1646+
name="request_input",
1647+
edges=[
1648+
(
1649+
START,
1650+
process_input,
1651+
draft_email,
1652+
request_human_review,
1653+
handle_human_review,
1654+
),
1655+
(handle_human_review, {"revise": draft_email, "approved": end_node}),
1656+
],
1657+
)
1658+
1659+
app = App(
1660+
name=request.function.__name__,
1661+
root_agent=wf,
1662+
)
1663+
runner = testing_utils.InMemoryRunner(app=app)
1664+
1665+
# When Turn 1 executes (starts and interrupts)
1666+
events1 = await runner.run_async(
1667+
testing_utils.get_user_content("my complaint")
1668+
)
1669+
1670+
# Then verify it interrupted at request_human_review@1
1671+
req1 = workflow_testing_utils.get_request_input_events(events1)
1672+
assert len(req1) == 1
1673+
assert 'request_human_review@1' in req1[0].node_info.path
1674+
1675+
inv_id = events1[0].invocation_id
1676+
1677+
# When Turn 2 executes (provides response and loops back)
1678+
events2 = await runner.run_async(
1679+
new_message=testing_utils.UserContent(
1680+
create_request_input_response('human_review', {'result': 'make it shorter'})
1681+
),
1682+
invocation_id=inv_id,
1683+
)
1684+
1685+
# Then verify it triggered request_human_review again with run_id @2
1686+
req2 = workflow_testing_utils.get_request_input_events(events2)
1687+
assert len(req2) == 1
1688+
assert 'request_human_review@2' in req2[0].node_info.path

0 commit comments

Comments
 (0)