Skip to content

Commit d7de2f7

Browse files
committed
feat(workflow): Support parsing input schema from Content text
Change-Id: Id20483da220c7b0d84bea68674f6614657a2dc42
1 parent 61a8b5b commit d7de2f7

2 files changed

Lines changed: 63 additions & 6 deletions

File tree

src/google/adk/workflow/_base_node.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,17 @@ def _validate_schema(self, data: Any, schema: Any) -> Any:
125125

126126
def _validate_input_data(self, data: Any) -> Any:
127127
"""Validates data against input_schema if set."""
128-
if self.input_schema is str and isinstance(data, types.Content):
129-
return ''.join(part.text for part in data.parts if part.text)
128+
if self.input_schema and isinstance(data, types.Content):
129+
text = "".join(part.text for part in data.parts if part.text)
130+
if self.input_schema is str:
131+
return text
132+
try:
133+
return TypeAdapter(self.input_schema).validate_json(text)
134+
except Exception:
135+
try:
136+
return TypeAdapter(self.input_schema).validate_python(text)
137+
except Exception:
138+
pass
130139
return self._validate_schema(data, self.input_schema)
131140

132141
def _validate_output_data(self, data: Any) -> Any:

tests/unittests/workflow/test_workflow_schema.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -543,18 +543,66 @@ async def _run_impl(
543543
assert any(e.output == 'done' for e in data_events)
544544

545545

546-
@pytest.mark.xfail(reason='Input schema parsing not yet in new Workflow.')
547546
@pytest.mark.asyncio
548547
async def test_start_node_with_int_input_schema():
549548
"""input_schema=int parses user text to int."""
550-
assert False, 'TODO'
549+
550+
class _AssertingNode(BaseNode):
551+
552+
async def _run_impl(
553+
self, *, ctx: Context, node_input: Any
554+
) -> AsyncGenerator[Any, None]:
555+
assert node_input == 42
556+
yield 'done'
557+
558+
node = _AssertingNode(name='node', input_schema=int)
559+
wf = Workflow(name='wf', edges=[(START, node)])
560+
561+
ss = InMemorySessionService()
562+
runner = Runner(app_name='test', node=wf, session_service=ss)
563+
session = await ss.create_session(app_name='test', user_id='u')
564+
565+
msg = types.Content(parts=[types.Part(text='42')], role='user')
566+
events = []
567+
568+
async for event in runner.run_async(
569+
user_id='u', session_id=session.id, new_message=msg
570+
):
571+
events.append(event)
572+
573+
data_events = [e for e in events if isinstance(e, Event) and e.output]
574+
assert any(e.output == 'done' for e in data_events)
551575

552576

553-
@pytest.mark.xfail(reason='Input schema parsing not yet in new Workflow.')
554577
@pytest.mark.asyncio
555578
async def test_start_node_with_int_list_input_schema():
556579
"""input_schema=list[int] parses JSON list."""
557-
assert False, 'TODO'
580+
581+
class _AssertingNode(BaseNode):
582+
583+
async def _run_impl(
584+
self, *, ctx: Context, node_input: Any
585+
) -> AsyncGenerator[Any, None]:
586+
assert node_input == [1, 2, 3]
587+
yield 'done'
588+
589+
node = _AssertingNode(name='node', input_schema=list[int])
590+
wf = Workflow(name='wf', edges=[(START, node)])
591+
592+
ss = InMemorySessionService()
593+
runner = Runner(app_name='test', node=wf, session_service=ss)
594+
session = await ss.create_session(app_name='test', user_id='u')
595+
596+
msg = types.Content(parts=[types.Part(text='[1, 2, 3]')], role='user')
597+
events = []
598+
599+
async for event in runner.run_async(
600+
user_id='u', session_id=session.id, new_message=msg
601+
):
602+
events.append(event)
603+
604+
data_events = [e for e in events if isinstance(e, Event) and e.output]
605+
assert any(e.output == 'done' for e in data_events)
558606

559607

560608
@pytest.mark.asyncio

0 commit comments

Comments
 (0)