Skip to content

Commit 3efeee5

Browse files
committed
feat(workflow): Support JSON string parsing in schema validation
Automatically parse JSON strings into dicts or Pydantic models when input_schema or output_schema is defined on a node.

This update refines the output schema validation to:
1. Try standard Pydantic validation first, to preserve valid types (such as raw strings or Content objects when requested).
2. Fall back to extracting text and parsing it as JSON ONLY if the data is a `types.Content` object and standard validation fails.

This ensures robust handling of LLM outputs while avoiding accidental parsing of valid raw strings. Added comprehensive unit tests covering both Content parsing and raw-string validation failures.

Change-Id: I68fe4c636365d4d3c1c458fbf863e18cfcbd8479
1 parent a5392c7 commit 3efeee5

2 files changed

Lines changed: 128 additions & 2 deletions

File tree

src/google/adk/workflow/_base_node.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,15 @@ def _validate_schema(self, data: Any, schema: Any) -> Any:
126126
def _validate_input_data(self, data: Any) -> Any:
127127
"""Validates data against input_schema if set."""
128128
if self.input_schema and isinstance(data, types.Content):
129-
text = "".join(part.text for part in data.parts if part.text)
129+
# Extract text from Content (e.g. user input from START node).
130+
text = ''.join(part.text for part in data.parts if part.text)
130131
if self.input_schema is str:
131132
return text
133+
# If schema is defined, try to parse the text as JSON.
132134
try:
133135
return TypeAdapter(self.input_schema).validate_json(text)
134136
except Exception:
137+
# Fallback to validate_python if it's a raw string matching the schema.
135138
try:
136139
return TypeAdapter(self.input_schema).validate_python(text)
137140
except Exception:
@@ -140,7 +143,27 @@ def _validate_input_data(self, data: Any) -> Any:
140143

141144
def _validate_output_data(self, data: Any) -> Any:
142145
"""Validates data against output_schema if set."""
143-
return self._validate_schema(data, self.output_schema)
146+
if not self.output_schema:
147+
return data
148+
149+
# 1. Try standard validation first
150+
try:
151+
return self._validate_schema(data, self.output_schema)
152+
except Exception as e:
153+
# 2. If failed, try to parse JSON ONLY if it's Content
154+
if isinstance(data, types.Content):
155+
text = ''.join(part.text for part in data.parts if part.text)
156+
if self.output_schema is str:
157+
return text
158+
if text.strip():
159+
try:
160+
validated = TypeAdapter(self.output_schema).validate_json(text)
161+
return self._to_serializable(validated)
162+
except Exception:
163+
pass
164+
165+
# 3. If not Content or parsing failed, re-raise original error
166+
raise e
144167

145168
@staticmethod
146169
def _to_serializable(data: Any) -> Any:

tests/unittests/workflow/test_workflow_schema.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,3 +668,106 @@ async def _run_impl(
668668

669669
data_events = [e for e in events if isinstance(e, Event) and e.output]
670670
assert any(e.output == 'done' for e in data_events)
671+
672+
673+
@pytest.mark.asyncio
async def test_workflow_with_invalid_output_schema():
  """A terminal output that violates the workflow output_schema raises ValidationError."""

  from pydantic import ValidationError

  class _MyModel(BaseModel):
    name: str

  class _MyNode(BaseNode):

    async def _run_impl(
        self, *, ctx: Context, node_input: Any
    ) -> AsyncGenerator[Any, None]:
      # No `name` key, so this dict cannot satisfy _MyModel.
      yield {'age': 10}

  terminal_node = _MyNode(name='node')
  workflow = Workflow(
      name='wf',
      edges=[(START, terminal_node)],
      output_schema=_MyModel,
  )

  session_service = InMemorySessionService()
  runner = Runner(app_name='test', node=workflow, session_service=session_service)
  session = await session_service.create_session(app_name='test', user_id='u')

  user_message = types.Content(parts=[types.Part(text='hello')], role='user')

  # Drain the event stream; the validation error is expected to surface
  # while iterating.
  with pytest.raises(ValidationError):
    async for _ in runner.run_async(
        user_id='u', session_id=session.id, new_message=user_message
    ):
      pass
703+
704+
705+
@pytest.mark.asyncio
async def test_node_returns_content_json_parsed():
  """A types.Content output holding JSON text is parsed when the node has an output_schema."""

  class _MyModel(BaseModel):
    name: str
    age: int

  class _MyNode(BaseNode):

    async def _run_impl(
        self, *, ctx: Context, node_input: Any
    ) -> AsyncGenerator[Any, None]:
      # Wrap the JSON payload in Content so the text-extraction fallback of
      # _validate_output_data is exercised.
      json_content = types.Content(
          parts=[types.Part(text='{"name": "Alice", "age": 30}')]
      )
      yield self._validate_output_data(json_content)

  json_node = _MyNode(name='node', output_schema=_MyModel)
  workflow = Workflow(name='wf', edges=[(START, json_node)])

  session_service = InMemorySessionService()
  runner = Runner(app_name='test', node=workflow, session_service=session_service)
  session = await session_service.create_session(app_name='test', user_id='u')

  user_message = types.Content(parts=[types.Part(text='hello')], role='user')

  events = [
      event
      async for event in runner.run_async(
          user_id='u', session_id=session.id, new_message=user_message
      )
  ]

  # Keep only events that actually carry an output payload.
  data_events = [e for e in events if isinstance(e, Event) and e.output]

  assert len(data_events) == 1
  assert data_events[0].output == {'name': 'Alice', 'age': 30}
741+
742+
743+
@pytest.mark.asyncio
async def test_node_returns_raw_string_not_parsed():
  """A raw JSON string output is NOT parsed, even when the node has an output_schema."""
  from pydantic import ValidationError

  class _MyModel(BaseModel):
    name: str
    age: int

  class _MyNode(BaseNode):

    async def _run_impl(
        self, *, ctx: Context, node_input: Any
    ) -> AsyncGenerator[Any, None]:
      # A plain str is not a dict/model, so validation is expected to fail
      # rather than fall back to JSON parsing.
      raw_json = '{"name": "Alice", "age": 30}'
      yield self._validate_output_data(raw_json)

  string_node = _MyNode(name='node', output_schema=_MyModel)
  workflow = Workflow(name='wf', edges=[(START, string_node)])

  session_service = InMemorySessionService()
  runner = Runner(app_name='test', node=workflow, session_service=session_service)
  session = await session_service.create_session(app_name='test', user_id='u')

  user_message = types.Content(parts=[types.Part(text='hello')], role='user')

  with pytest.raises(ValidationError):
    async for _ in runner.run_async(
        user_id='u', session_id=session.id, new_message=user_message
    ):
      pass

0 commit comments

Comments
 (0)