From 88827b60015c3266541c3d7ccefbec9baf062201 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Fri, 15 May 2026 10:49:46 -0600 Subject: [PATCH 1/4] Add type-aware json serialization/object encoding --- CHANGELOG.md | 4 + .../decorators/durable_app.py | 20 +- .../models/DurableEntityContext.py | 39 +- .../models/DurableOrchestrationClient.py | 4 +- .../models/DurableOrchestrationContext.py | 79 ++- .../models/OrchestratorState.py | 4 +- azure/durable_functions/models/Task.py | 7 + .../models/TaskOrchestrationExecutor.py | 18 +- .../models/actions/CallActivityAction.py | 5 +- .../actions/CallActivityWithRetryAction.py | 5 +- .../models/actions/CallEntityAction.py | 5 +- .../actions/CallSubOrchestratorAction.py | 5 +- .../CallSubOrchestratorWithRetryAction.py | 5 +- .../models/actions/ContinueAsNewAction.py | 5 +- .../models/actions/SignalEntityAction.py | 5 +- .../models/entities/EntityState.py | 4 +- .../models/entities/OperationResult.py | 5 +- .../models/utils/df_serialization.py | 226 ++++++ .../models/utils/type_discovery.py | 83 +++ azure/durable_functions/orchestrator.py | 7 +- tests/models/test_Decorators.py | 16 + .../test_DurableOrchestrationContext.py | 71 ++ tests/orchestrator/test_expected_type.py | 164 +++++ tests/orchestrator/test_external_event.py | 36 +- tests/utils/test_df_serialization.py | 656 ++++++++++++++++++ tests/utils/test_type_discovery.py | 81 +++ 26 files changed, 1497 insertions(+), 62 deletions(-) create mode 100644 azure/durable_functions/models/utils/df_serialization.py create mode 100644 azure/durable_functions/models/utils/type_discovery.py create mode 100644 tests/orchestrator/test_expected_type.py create mode 100644 tests/utils/test_df_serialization.py create mode 100644 tests/utils/test_type_discovery.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 3216b2c..d9af378 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ All notable changes to this project will be documented in this file. 
### Added - Client operation correlation logging: `FunctionInvocationId` is now propagated via HTTP headers to the host for client operations, enabling correlation with host logs. +- Centralized JSON serialization module (`azure.durable_functions.models.utils.df_serialization`): all serialization/deserialization of user payloads (orchestrator inputs/outputs, activity arguments and results, sub-orchestrator payloads, entity inputs/outputs, and client inputs) now flows through `df_dumps` / `df_loads`, replacing scattered `json.dumps(…, default=_serialize_custom_object)` / `json.loads(…, object_hook=_deserialize_custom_object)` calls. The wire format is **unchanged** — builtins serialize to plain JSON and custom objects continue to use the `{"__class__", "__module__", "__data__"}` convention. +- Type-hint-driven validation via `df_loads(s, expected_type=...)`: when the V2 programming model provides a return-type annotation for an activity or sub-orchestrator, `df_loads` validates the deserialized payload against that type **before** the legacy `object_hook` fires, catching class/module mismatches early. +- **Strict typing mode** (opt-in via `AZURE_FUNCTIONS_DURABLE_STRICT_TYPING=1`): when enabled, `import_module` is never called on either encode or decode. On encode, `df_dumps` wraps only the top-level custom object — `to_json()` must return plain-JSON-serializable data (nested custom objects must be serialized explicitly). On decode, `df_loads` calls `expected_type.from_json(raw["__data__"])` directly; `df_loads` without `expected_type` raises `TypeError` for custom-object payloads. A `TypeError` is also raised on type mismatch. +- Return-type discovery for V2 decorated activities/sub-orchestrators (`azure.durable_functions.models.utils.type_discovery`): resolves the concrete return annotation from the user's registered function, used to supply `expected_type` to `df_loads`. 
## 1.0.0b6 diff --git a/azure/durable_functions/decorators/durable_app.py b/azure/durable_functions/decorators/durable_app.py index 2b74a04..b885e4a 100644 --- a/azure/durable_functions/decorators/durable_app.py +++ b/azure/durable_functions/decorators/durable_app.py @@ -76,7 +76,7 @@ def decorator(entity_func): return decorator - def _configure_orchestrator_callable(self, wrap) -> Callable: + def _configure_orchestrator_callable(self, wrap, input_type=None) -> Callable: """Obtain decorator to construct an Orchestrator class from a user-defined Function. In the old programming model, this decorator's logic was unavoidable boilerplate @@ -86,6 +86,9 @@ def _configure_orchestrator_callable(self, wrap) -> Callable: ---------- wrap: Callable The next decorator to be applied. + input_type: Optional[type] + The expected type for orchestration input, forwarded from + the orchestration_trigger decorator. Returns ------- @@ -99,12 +102,16 @@ def decorator(orchestrator_func): # invoke next decorator, with the Orchestrator as input handle.__name__ = orchestrator_func.__name__ + # Stash the decorator-declared input type so the runtime + # can feed it to df_loads via context.get_input(). + handle._df_input_type = input_type return wrap(handle) return decorator def orchestration_trigger(self, context_name: str, - orchestration: Optional[str] = None): + orchestration: Optional[str] = None, + input_type: Optional[type] = None): """Register an Orchestrator Function. Parameters @@ -114,8 +121,13 @@ def orchestration_trigger(self, context_name: str, orchestration: Optional[str] Name of Orchestrator Function. The value is None by default, in which case the name of the method is used. + input_type: Optional[type] + The expected type for the orchestration input. When set, + ``context.get_input()`` will use this type to decode the + input payload without consulting ``sys.modules``. A + call-site ``expected_type`` argument on ``get_input`` + takes precedence over this value. 
""" - @self._configure_orchestrator_callable @self._configure_function_builder def wrap(fb): @@ -127,7 +139,7 @@ def decorator(): return decorator() - return wrap + return self._configure_orchestrator_callable(wrap, input_type=input_type) def activity_trigger(self, input_name: str, activity: Optional[str] = None): diff --git a/azure/durable_functions/models/DurableEntityContext.py b/azure/durable_functions/models/DurableEntityContext.py index 37cc980..43c8f32 100644 --- a/azure/durable_functions/models/DurableEntityContext.py +++ b/azure/durable_functions/models/DurableEntityContext.py @@ -1,5 +1,5 @@ from typing import Optional, Any, Dict, Tuple, List, Callable -from azure.functions._durable_functions import _deserialize_custom_object +from .utils.df_serialization import df_loads import json @@ -36,6 +36,7 @@ def __init__(self, self._is_newly_constructed: bool = False self._state: Any = state + self._state_is_raw: bool = False self._input: Any = None self._operation: Optional[str] = None self._result: Any = None @@ -109,10 +110,17 @@ def from_json(cls, json_str: str) -> Tuple['DurableEntityContext', List[Dict[str serialized_state = json_dict["state"] if serialized_state is not None: - json_dict["state"] = from_json_util(serialized_state) + # Keep the raw serialized form so get_state() can deserialize + # lazily with an expected_type supplied by the user. + json_dict["state"] = serialized_state + else: + json_dict["state"] = None batch = json_dict.pop("batch") - return cls(**json_dict), batch + ctx = cls(**json_dict) + if serialized_state is not None: + ctx._state_is_raw = True + return ctx, batch def set_state(self, state: Any) -> None: """Set the state of the entity. 
@@ -127,19 +135,26 @@ def set_state(self, state: Any) -> None: # should only serialize the state at the end of the batch self._state = state - def get_state(self, initializer: Optional[Callable[[], Any]] = None) -> Any: + def get_state(self, initializer: Optional[Callable[[], Any]] = None, + expected_type: Optional[type] = None) -> Any: """Get the current state of this entity. Parameters ---------- initializer: Optional[Callable[[], Any]] A 0-argument function to provide an initial state. Defaults to None. + expected_type: Optional[type] + The type to decode the state as. When set, the codec uses + this type directly without consulting ``sys.modules``. Returns ------- Any The current state of the entity """ + if self._state is not None and self._state_is_raw: + self._state = from_json_util(self._state, expected_type=expected_type) + self._state_is_raw = False state = self._state if state is not None: return state @@ -149,9 +164,15 @@ def get_state(self, initializer: Optional[Callable[[], Any]] = None) -> Any: state = initializer() return state - def get_input(self) -> Any: + def get_input(self, expected_type: Optional[type] = None) -> Any: """Get the input for this operation. + Parameters + ---------- + expected_type: Optional[type] + The type to decode the input as. When set, the codec uses + this type directly without consulting ``sys.modules``. + Returns ------- Any @@ -160,7 +181,7 @@ def get_input(self) -> Any: input_ = None req_input = self._input req_input = json.loads(req_input) - input_ = None if req_input is None else from_json_util(req_input) + input_ = None if req_input is None else df_loads(req_input, expected_type=expected_type) return input_ def set_result(self, result: Any) -> None: @@ -180,7 +201,7 @@ def destruct_on_exit(self) -> None: self._state = None -def from_json_util(json_str: str) -> Any: +def from_json_util(json_str: str, expected_type: Optional[type] = None) -> Any: """Load an arbitrary datatype from its JSON representation. 
The Out-of-proc SDK has a special JSON encoding strategy @@ -192,10 +213,12 @@ def from_json_util(json_str: str) -> Any: ---------- json_str: str A JSON-formatted string, from durable-extension + expected_type: Optional[type] + The type to decode the value as. Returns ------- Any: The original datatype that was serialized """ - return json.loads(json_str, object_hook=_deserialize_custom_object) + return df_loads(json_str, expected_type=expected_type) diff --git a/azure/durable_functions/models/DurableOrchestrationClient.py b/azure/durable_functions/models/DurableOrchestrationClient.py index 009001e..b6acc38 100644 --- a/azure/durable_functions/models/DurableOrchestrationClient.py +++ b/azure/durable_functions/models/DurableOrchestrationClient.py @@ -16,7 +16,7 @@ from ..models.DurableOrchestrationBindings import DurableOrchestrationBindings from .utils.http_utils import get_async_request, post_async_request, delete_async_request from .utils.entity_utils import EntityId -from azure.functions._durable_functions import _serialize_custom_object +from .utils.df_serialization import df_dumps class DurableOrchestrationClient: @@ -633,7 +633,7 @@ def _get_json_input(client_input: object) -> Optional[str]: If the JSON serialization failed, see `serialize_custom_object` """ if client_input is not None: - return json.dumps(client_input, default=_serialize_custom_object) + return df_dumps(client_input) return None @staticmethod diff --git a/azure/durable_functions/models/DurableOrchestrationContext.py b/azure/durable_functions/models/DurableOrchestrationContext.py index 531307c..d280336 100644 --- a/azure/durable_functions/models/DurableOrchestrationContext.py +++ b/azure/durable_functions/models/DurableOrchestrationContext.py @@ -34,7 +34,11 @@ from .actions import Action from ..models.TokenSource import TokenSource from .utils.entity_utils import EntityId -from azure.functions._durable_functions import _deserialize_custom_object +from .utils.df_serialization import df_loads 
+from .utils.type_discovery import ( + activity_output_type, + sub_orchestrator_output_type, +) from azure.durable_functions.constants import DATETIME_STRING_FORMAT from azure.durable_functions.decorators.metadata import OrchestrationTrigger, ActivityTrigger from azure.functions.decorators.function_app import FunctionBuilder @@ -167,7 +171,8 @@ def _set_is_replaying(self, is_replaying: bool): """ self._is_replaying = is_replaying - def call_activity(self, name: Union[str, Callable], input_: Optional[Any] = None) -> TaskBase: + def call_activity(self, name: Union[str, Callable], input_: Optional[Any] = None, + expected_type: Optional[type] = None) -> TaskBase: """Schedule an activity for execution. Parameters @@ -177,6 +182,10 @@ def call_activity(self, name: Union[str, Callable], input_: Optional[Any] = None in the Python V2 programming model, the activity function itself. input_: Optional[Any] The JSON-serializable input to pass to the activity function. + expected_type: Optional[type] + The type to decode the activity result as. Takes precedence + over the type discovered from the activity's return + annotation. Returns ------- @@ -191,16 +200,21 @@ def call_activity(self, name: Union[str, Callable], input_: Optional[Any] = None "decorator. Otherwise, provide in the name of the activity as a string." raise ValueError(error_message) + # Discover the activity's return type from its annotation, if any, + # so the result can be decoded without consulting sys.modules. 
+ resolved_type = expected_type or activity_output_type(name) if isinstance(name, FunctionBuilder): name = self._get_function_name(name, ActivityTrigger) action = CallActivityAction(name, input_) task = self._generate_task(action) + task._expected_output_type = resolved_type return task def call_activity_with_retry(self, name: Union[str, Callable], retry_options: RetryOptions, - input_: Optional[Any] = None) -> TaskBase: + input_: Optional[Any] = None, + expected_type: Optional[type] = None) -> TaskBase: """Schedule an activity for execution with retry options. Parameters @@ -212,6 +226,10 @@ def call_activity_with_retry(self, The retry options for the activity function. input_: Optional[Any] The JSON-serializable input to pass to the activity function. + expected_type: Optional[type] + The type to decode the activity result as. Takes precedence + over the type discovered from the activity's return + annotation. Returns ------- @@ -227,11 +245,13 @@ def call_activity_with_retry(self, "decorator. Otherwise, provide in the name of the activity as a string." raise ValueError(error_message) + resolved_type = expected_type or activity_output_type(name) if isinstance(name, FunctionBuilder): name = self._get_function_name(name, ActivityTrigger) action = CallActivityWithRetryAction(name, retry_options, input_) task = self._generate_task(action, retry_options) + task._expected_output_type = resolved_type return task def call_http(self, method: str, uri: str, content: Optional[str] = None, @@ -288,7 +308,8 @@ def call_http(self, method: str, uri: str, content: Optional[str] = None, def call_sub_orchestrator(self, name: Union[str, Callable], input_: Optional[Any] = None, instance_id: Optional[str] = None, - version: Optional[str] = None) -> TaskBase: + version: Optional[str] = None, + expected_type: Optional[type] = None) -> TaskBase: """Schedule sub-orchestration function named `name` for execution. 
Parameters @@ -302,6 +323,10 @@ def call_sub_orchestrator(self, version: Optional[str] The version to assign to the sub-orchestration instance. If not specified, the defaultVersion from host.json will be used. + expected_type: Optional[type] + The type to decode the sub-orchestrator result as. Takes + precedence over the type discovered from the + sub-orchestrator's return annotation. Returns ------- @@ -316,18 +341,21 @@ def call_sub_orchestrator(self, "decorator. Otherwise, provide in the name of the activity as a string." raise ValueError(error_message) + resolved_type = expected_type or sub_orchestrator_output_type(name) if isinstance(name, FunctionBuilder): name = self._get_function_name(name, OrchestrationTrigger) action = CallSubOrchestratorAction(name, input_, instance_id, version) task = self._generate_task(action) + task._expected_output_type = resolved_type return task def call_sub_orchestrator_with_retry(self, name: Union[str, Callable], retry_options: RetryOptions, input_: Optional[Any] = None, instance_id: Optional[str] = None, - version: Optional[str] = None) -> TaskBase: + version: Optional[str] = None, + expected_type: Optional[type] = None) -> TaskBase: """Schedule sub-orchestration function named `name` for execution, with retry-options. Parameters @@ -343,6 +371,10 @@ def call_sub_orchestrator_with_retry(self, version: Optional[str] The version to assign to the sub-orchestration instance. If not specified, the defaultVersion from host.json will be used. + expected_type: Optional[type] + The type to decode the sub-orchestrator result as. Takes + precedence over the type discovered from the + sub-orchestrator's return annotation. Returns ------- @@ -357,18 +389,31 @@ def call_sub_orchestrator_with_retry(self, "decorator. Otherwise, provide in the name of the activity as a string." 
raise ValueError(error_message) + resolved_type = expected_type or sub_orchestrator_output_type(name) if isinstance(name, FunctionBuilder): name = self._get_function_name(name, OrchestrationTrigger) action = CallSubOrchestratorWithRetryAction( name, retry_options, input_, instance_id, version) task = self._generate_task(action, retry_options) + task._expected_output_type = resolved_type return task - def get_input(self) -> Optional[Any]: - """Get the orchestration input.""" - return None if self._input is None else json.loads(self._input, - object_hook=_deserialize_custom_object) + def get_input(self, expected_type: Optional[type] = None) -> Optional[Any]: + """Get the orchestration input. + + Parameters + ---------- + expected_type : Optional[type] + The type to decode the input as. Takes precedence over + the ``input_type`` declared on the orchestration trigger + decorator. When neither is set, decoding falls back to + module-only class resolution. + """ + if self._input is None: + return None + resolved = expected_type or getattr(self, "_input_expected_type", None) + return df_loads(self._input, expected_type=resolved) def new_uuid(self) -> str: """Create a new UUID that is safe for replay within an orchestration or operation. @@ -535,7 +580,8 @@ def function_context(self) -> FunctionContext: return self._function_context def call_entity(self, entityId: EntityId, - operationName: str, operationInput: Optional[Any] = None): + operationName: str, operationInput: Optional[Any] = None, + expected_type: Optional[type] = None): """Get the result of Durable Entity operation given some input. Parameters @@ -546,6 +592,10 @@ def call_entity(self, entityId: EntityId, The operation to execute operationInput: Optional[Any] The input for tne operation, defaults to None. + expected_type: Optional[type] + The type to decode the entity response as. When set, the + codec uses this type directly without consulting + ``sys.modules``. 
Returns ------- @@ -554,6 +604,7 @@ def call_entity(self, entityId: EntityId, """ action = CallEntityAction(entityId, operationName, operationInput) task = self._generate_task(action) + task._expected_output_type = expected_type return task def _record_fire_and_forget_action(self, action: Action): @@ -627,13 +678,18 @@ def create_timer(self, fire_at: datetime.datetime) -> TaskBase: task = self._generate_task(action, task_constructor=TimerTask) return task - def wait_for_external_event(self, name: str) -> TaskBase: + def wait_for_external_event(self, name: str, + expected_type: Optional[type] = None) -> TaskBase: """Wait asynchronously for an event to be raised with the name `name`. Parameters ---------- name : str The event name of the event that the task is waiting for. + expected_type : Optional[type] + The type to decode the event payload as. When set, the + codec uses this type directly without consulting + ``sys.modules``. Returns ------- @@ -642,6 +698,7 @@ def wait_for_external_event(self, name: str) -> TaskBase: """ action = WaitForExternalEventAction(name) task = self._generate_task(action, id_=name) + task._expected_output_type = expected_type return task def continue_as_new(self, input_: Any): diff --git a/azure/durable_functions/models/OrchestratorState.py b/azure/durable_functions/models/OrchestratorState.py index 7b42629..36fa2b2 100644 --- a/azure/durable_functions/models/OrchestratorState.py +++ b/azure/durable_functions/models/OrchestratorState.py @@ -4,8 +4,8 @@ from azure.durable_functions.models.ReplaySchema import ReplaySchema from .utils.json_utils import add_attrib +from .utils.df_serialization import _get_serialize_default from azure.durable_functions.models.actions.Action import Action -from azure.functions._durable_functions import _serialize_custom_object class OrchestratorState: @@ -114,4 +114,4 @@ def to_json_string(self) -> str: The instance of the object in json string format """ json_dict = self.to_json() - return 
json.dumps(json_dict, default=_serialize_custom_object) + return json.dumps(json_dict, default=_get_serialize_default()) diff --git a/azure/durable_functions/models/Task.py b/azure/durable_functions/models/Task.py index 7aa5b25..e566700 100644 --- a/azure/durable_functions/models/Task.py +++ b/azure/durable_functions/models/Task.py @@ -58,6 +58,13 @@ def __init__(self, id_: Union[int, str], actions: Union[List[Action], Action]): self.action_repr: Union[List[Action], Action] = actions self.is_played = False self._is_scheduled_flag = False + # The expected return type discovered from the user function's + # annotation, when the task was scheduled with a V2 FunctionBuilder. + # Forwarded to ``df_loads`` so custom objects can be decoded without + # touching ``sys.modules``/``importlib``. ``None`` means "no type + # info available" -- the codec then falls back to module lookup + # and, ultimately, the legacy decoder with a warning. + self._expected_output_type: Optional[type] = None @property def _is_scheduled(self) -> bool: diff --git a/azure/durable_functions/models/TaskOrchestrationExecutor.py b/azure/durable_functions/models/TaskOrchestrationExecutor.py index efe7adb..73fd63f 100644 --- a/azure/durable_functions/models/TaskOrchestrationExecutor.py +++ b/azure/durable_functions/models/TaskOrchestrationExecutor.py @@ -9,7 +9,7 @@ from collections import namedtuple import json from ..models.entities.ResponseMessage import ResponseMessage -from azure.functions._durable_functions import _deserialize_custom_object +from .utils.df_serialization import df_loads class TaskOrchestrationExecutor: @@ -181,18 +181,21 @@ def parse_history_event(directive_result): raise ValueError("EventType is not found in task object") # We provide the ability to deserialize custom objects, because the output of this - # will be passed directly to the orchestrator as the output of some activity + # will be passed directly to the orchestrator as the output of some activity. 
+ # The expected type (when discoverable from the activity / sub-orchestrator's + # return annotation) lets ``df_loads`` decode custom classes without consulting + # ``sys.modules`` / ``importlib``. + expected_type = getattr(task, "_expected_output_type", None) if (event_type == HistoryEventType.SUB_ORCHESTRATION_INSTANCE_COMPLETED and directive_result.Result is not None): - return json.loads(directive_result.Result, object_hook=_deserialize_custom_object) + return df_loads(directive_result.Result, expected_type=expected_type) if (event_type == HistoryEventType.TASK_COMPLETED and directive_result.Result is not None): - return json.loads(directive_result.Result, object_hook=_deserialize_custom_object) + return df_loads(directive_result.Result, expected_type=expected_type) if (event_type == HistoryEventType.EVENT_RAISED and directive_result.Input is not None): # TODO: Investigate why the payload is in "Input" instead of "Result" - response = json.loads(directive_result.Input, - object_hook=_deserialize_custom_object) + response = df_loads(directive_result.Input, expected_type=expected_type) return response return None @@ -217,7 +220,8 @@ def parse_history_event(directive_result): new_value = parse_history_event(event) if task._api_name == "CallEntityAction": event_payload = ResponseMessage.from_dict(new_value) - new_value = json.loads(event_payload.result) + entity_expected = getattr(task, "_expected_output_type", None) + new_value = df_loads(event_payload.result, expected_type=entity_expected) if event_payload.is_exception: new_value = Exception(new_value) diff --git a/azure/durable_functions/models/actions/CallActivityAction.py b/azure/durable_functions/models/actions/CallActivityAction.py index 2e5c4ad..ea3fe7c 100644 --- a/azure/durable_functions/models/actions/CallActivityAction.py +++ b/azure/durable_functions/models/actions/CallActivityAction.py @@ -3,8 +3,7 @@ from .Action import Action from .ActionType import ActionType from ..utils.json_utils import 
add_attrib -from json import dumps -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps class CallActivityAction(Action): @@ -16,7 +15,7 @@ class CallActivityAction(Action): def __init__(self, function_name: str, input_=None): self.function_name: str = function_name # It appears that `.input_` needs to be JSON-serializable at this point - self.input_ = dumps(input_, default=_serialize_custom_object) + self.input_ = df_dumps(input_) if not self.function_name: raise ValueError("function_name cannot be empty") diff --git a/azure/durable_functions/models/actions/CallActivityWithRetryAction.py b/azure/durable_functions/models/actions/CallActivityWithRetryAction.py index a6b3328..e21cda5 100644 --- a/azure/durable_functions/models/actions/CallActivityWithRetryAction.py +++ b/azure/durable_functions/models/actions/CallActivityWithRetryAction.py @@ -1,11 +1,10 @@ -from json import dumps from typing import Dict, Union from .Action import Action from .ActionType import ActionType from ..RetryOptions import RetryOptions from ..utils.json_utils import add_attrib, add_json_attrib -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps class CallActivityWithRetryAction(Action): @@ -18,7 +17,7 @@ def __init__(self, function_name: str, retry_options: RetryOptions, input_=None): self.function_name: str = function_name self.retry_options: RetryOptions = retry_options - self.input_ = dumps(input_, default=_serialize_custom_object) + self.input_ = df_dumps(input_) if not self.function_name: raise ValueError("function_name cannot be empty") diff --git a/azure/durable_functions/models/actions/CallEntityAction.py b/azure/durable_functions/models/actions/CallEntityAction.py index 55baa4e..894914a 100644 --- a/azure/durable_functions/models/actions/CallEntityAction.py +++ b/azure/durable_functions/models/actions/CallEntityAction.py @@ -3,8 +3,7 @@ from .Action 
import Action from .ActionType import ActionType from ..utils.json_utils import add_attrib -from json import dumps -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps from ..utils.entity_utils import EntityId @@ -23,7 +22,7 @@ def __init__(self, entity_id: EntityId, operation: str, input_=None): self.instance_id: str = EntityId.get_scheduler_id(entity_id) self.operation: str = operation - self.input_: str = dumps(input_, default=_serialize_custom_object) + self.input_: str = df_dumps(input_) @property def action_type(self) -> int: diff --git a/azure/durable_functions/models/actions/CallSubOrchestratorAction.py b/azure/durable_functions/models/actions/CallSubOrchestratorAction.py index 03a2241..2925e45 100644 --- a/azure/durable_functions/models/actions/CallSubOrchestratorAction.py +++ b/azure/durable_functions/models/actions/CallSubOrchestratorAction.py @@ -3,8 +3,7 @@ from .Action import Action from .ActionType import ActionType from ..utils.json_utils import add_attrib -from json import dumps -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps class CallSubOrchestratorAction(Action): @@ -13,7 +12,7 @@ class CallSubOrchestratorAction(Action): def __init__(self, function_name: str, _input: Optional[Any] = None, instance_id: Optional[str] = None, version: Optional[str] = None): self.function_name: str = function_name - self._input: str = dumps(_input, default=_serialize_custom_object) + self._input: str = df_dumps(_input) self.instance_id: Optional[str] = instance_id self.version: Optional[str] = version diff --git a/azure/durable_functions/models/actions/CallSubOrchestratorWithRetryAction.py b/azure/durable_functions/models/actions/CallSubOrchestratorWithRetryAction.py index c72d718..61c5bb7 100644 --- a/azure/durable_functions/models/actions/CallSubOrchestratorWithRetryAction.py +++ 
b/azure/durable_functions/models/actions/CallSubOrchestratorWithRetryAction.py @@ -3,9 +3,8 @@ from .Action import Action from .ActionType import ActionType from ..utils.json_utils import add_attrib, add_json_attrib -from json import dumps from ..RetryOptions import RetryOptions -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps class CallSubOrchestratorWithRetryAction(Action): @@ -15,7 +14,7 @@ def __init__(self, function_name: str, retry_options: RetryOptions, _input: Optional[Any] = None, instance_id: Optional[str] = None, version: Optional[str] = None): self.function_name: str = function_name - self._input: str = dumps(_input, default=_serialize_custom_object) + self._input: str = df_dumps(_input) self.retry_options: RetryOptions = retry_options self.instance_id: Optional[str] = instance_id self.version: Optional[str] = version diff --git a/azure/durable_functions/models/actions/ContinueAsNewAction.py b/azure/durable_functions/models/actions/ContinueAsNewAction.py index 7af0508..4573566 100644 --- a/azure/durable_functions/models/actions/ContinueAsNewAction.py +++ b/azure/durable_functions/models/actions/ContinueAsNewAction.py @@ -3,8 +3,7 @@ from .Action import Action from .ActionType import ActionType from ..utils.json_utils import add_attrib -from json import dumps -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps class ContinueAsNewAction(Action): @@ -15,7 +14,7 @@ class ContinueAsNewAction(Action): """ def __init__(self, input_=None): - self.input_ = dumps(input_, default=_serialize_custom_object) + self.input_ = df_dumps(input_) @property def action_type(self) -> int: diff --git a/azure/durable_functions/models/actions/SignalEntityAction.py b/azure/durable_functions/models/actions/SignalEntityAction.py index d6e9be5..d7ace9a 100644 --- a/azure/durable_functions/models/actions/SignalEntityAction.py +++ 
b/azure/durable_functions/models/actions/SignalEntityAction.py @@ -3,8 +3,7 @@ from .Action import Action from .ActionType import ActionType from ..utils.json_utils import add_attrib -from json import dumps -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps from ..utils.entity_utils import EntityId @@ -23,7 +22,7 @@ def __init__(self, entity_id: EntityId, operation: str, input_=None): self.instance_id: str = EntityId.get_scheduler_id(entity_id) self.operation: str = operation - self.input_: str = dumps(input_, default=_serialize_custom_object) + self.input_: str = df_dumps(input_) @property def action_type(self) -> int: diff --git a/azure/durable_functions/models/entities/EntityState.py b/azure/durable_functions/models/entities/EntityState.py index 13d22e7..1fabf6d 100644 --- a/azure/durable_functions/models/entities/EntityState.py +++ b/azure/durable_functions/models/entities/EntityState.py @@ -1,6 +1,6 @@ from typing import List, Optional, Dict, Any from .Signal import Signal -from azure.functions._durable_functions import _serialize_custom_object +from ..utils.df_serialization import df_dumps from .OperationResult import OperationResult import json @@ -56,7 +56,7 @@ def to_json(self) -> Dict[str, Any]: serialized_results = list(map(lambda x: x.to_json(), self.results)) json_dict["entityExists"] = self.entity_exists - json_dict["entityState"] = json.dumps(self.state, default=_serialize_custom_object) + json_dict["entityState"] = df_dumps(self.state) json_dict["results"] = serialized_results json_dict["signals"] = self.signals return json_dict diff --git a/azure/durable_functions/models/entities/OperationResult.py b/azure/durable_functions/models/entities/OperationResult.py index 05147f0..744dd28 100644 --- a/azure/durable_functions/models/entities/OperationResult.py +++ b/azure/durable_functions/models/entities/OperationResult.py @@ -1,6 +1,5 @@ from typing import Optional, Dict, Any -from 
azure.functions._durable_functions import _serialize_custom_object -import json +from ..utils.df_serialization import df_dumps class OperationResult: @@ -90,5 +89,5 @@ def to_json(self) -> Dict[str, Any]: to_json["isError"] = self.is_error to_json["duration"] = self.duration to_json["startTime"] = self.execution_start_time_ms - to_json["result"] = json.dumps(self.result, default=_serialize_custom_object) + to_json["result"] = df_dumps(self.result) return to_json diff --git a/azure/durable_functions/models/utils/df_serialization.py b/azure/durable_functions/models/utils/df_serialization.py new file mode 100644 index 0000000..31bae9b --- /dev/null +++ b/azure/durable_functions/models/utils/df_serialization.py @@ -0,0 +1,226 @@ +"""Centralized JSON serialization for Durable Functions payloads. + +This module wraps the legacy `json.dumps(value, default=_serialize_custom_object)` +/ `json.loads(s, object_hook=_deserialize_custom_object)` pipeline from +`azure.functions._durable_functions` behind `df_dumps` and `df_loads`. + +The wire format is **unchanged** -- builtins serialize to plain JSON and custom +objects use the `{"__class__": ..., "__module__": ..., "__data__": ...}` +convention that the Durable extension and downstream consumers already expect. + +`df_loads` adds an optional `expected_type` parameter that controls +type validation. Behavior depends on the typing mode: + +* **Loose mode** (default) -- the payload is inspected before + deserialization and a warning is logged on type mismatch, then the + legacy ``object_hook`` pipeline runs as usual. +* **Strict mode** -- ``import_module`` is never called on either side. + On encode, ``to_json`` is called on the top-level object only and + the result must be plain-JSON-serializable (nested custom objects + are **not** auto-encoded -- ``to_json`` must handle them). On + decode, ``expected_type.from_json`` is invoked directly with plain + JSON data. 
A ``TypeError`` is raised on type mismatch or if + ``expected_type`` is not provided for a custom-object payload. + Opt in by setting ``AZURE_FUNCTIONS_DURABLE_STRICT_TYPING`` to a + truthy value (``1``, ``true``, ``yes``).""" + +from __future__ import annotations + +import json +import logging +import os +from typing import Any, Optional + +from azure.functions._durable_functions import ( + _deserialize_custom_object, + _serialize_custom_object, +) + +logger = logging.getLogger(__name__) + +_STRICT_ENV_VAR = "AZURE_FUNCTIONS_DURABLE_STRICT_TYPING" +_TRUTHY = frozenset({"1", "true", "yes"}) + + +def _is_strict_mode() -> bool: + return os.environ.get(_STRICT_ENV_VAR, "").strip().lower() in _TRUTHY + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def df_dumps(value: Any) -> str: + """Serialize *value* to a JSON string. + + In **loose mode** (default), custom objects are encoded recursively + via the legacy ``default=_serialize_custom_object`` handler — any + nested custom object is automatically wrapped in the + ``{"__class__", "__module__", "__data__"}`` envelope. + + In **strict mode**, the top-level custom object (if it has + ``to_json``) is wrapped in the legacy envelope, but the + ``__data__`` payload is serialized as **plain JSON** — no + ``default=`` hook fires. This means ``to_json()`` must return a + value that is natively JSON-serializable (dicts, lists, strings, + numbers, bools, None). A ``TypeError`` is raised at encode time + if any nested value is not serializable. + """ + if _is_strict_mode(): + if hasattr(value, "to_json"): + envelope = _serialize_custom_object(value) + return json.dumps(envelope) + # Primitive / plain-JSON value — serialize without default= + # so stray custom objects are caught immediately. 
+ return json.dumps(value) + return json.dumps(value, default=_serialize_custom_object) + + +def df_loads(s: str, expected_type: Optional[type] = None) -> Any: + """Deserialize a JSON string, optionally validating the result type. + + Parameters + ---------- + s : str + The JSON-encoded payload. + expected_type : type, optional + When provided the raw JSON is parsed first (without triggering + ``import_module`` via the legacy ``object_hook``). If the + payload is a legacy custom-object dict its embedded class info + is validated against *expected_type* **before** any module is + imported. In strict mode *expected_type* is then used to call + ``from_json`` directly, avoiding ``import_module`` entirely, + and a ``TypeError`` is raised on mismatch. In loose mode a + warning is logged on mismatch and the legacy decoder still runs. + """ + if expected_type is not None: + return _loads_with_expected_type(s, expected_type) + + if _is_strict_mode(): + return _loads_strict_no_type(s) + + return json.loads(s, object_hook=_deserialize_custom_object) + + +def _loads_strict_no_type(s: str) -> Any: + """Strict-mode fallback when no *expected_type* is available. + + Parses without ``object_hook`` so ``import_module`` is never called. + If the top-level value is a legacy custom-object dict, raises + ``TypeError`` — the caller must supply an ``expected_type`` to + deserialize custom objects in strict mode. + """ + raw = json.loads(s) + if _is_legacy_custom_dict(raw): + raise TypeError( + "df_loads: strict mode requires expected_type to " + "deserialize custom-object payloads, but none was provided. " + f"Payload declares {raw['__module__']}.{raw['__class__']}." + ) + return raw + + +def _get_serialize_default(): + """Return the `default` callback for `json.dumps`. + + Use this in places that build their own `json.dumps` call (e.g. + `OrchestratorState.to_json_string`) rather than going through + `df_dumps`.
+ + In strict mode returns ``None`` — `OrchestratorState` fields are + already serialized via `df_dumps` so there should be no remaining + custom objects to encode. A stray custom object will raise + ``TypeError`` from ``json.dumps``, surfacing the problem early. + """ + if _is_strict_mode(): + return None + return _serialize_custom_object + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_LEGACY_KEYS = frozenset({"__class__", "__module__", "__data__"}) + + +def _is_legacy_custom_dict(d: Any) -> bool: + """Return True if *d* is a dict with legacy custom-object markers.""" + return isinstance(d, dict) and _LEGACY_KEYS.issubset(d) + + +def _loads_with_expected_type(s: str, expected_type: type) -> Any: + """Parse *s* and validate against *expected_type*. + + The raw JSON is parsed **without** the legacy ``object_hook`` so we + can inspect the payload before ``import_module`` fires. + + * **Strict mode** -- for custom-object payloads, calls + ``expected_type.from_json`` directly (no ``import_module``). For + primitives, validates then returns the plain value. Raises + ``TypeError`` on mismatch. + * **Loose mode** -- logs a warning on mismatch, then falls through + to the normal ``json.loads(s, object_hook=...)`` legacy path. + """ + raw = json.loads(s) + strict = _is_strict_mode() + + if _is_legacy_custom_dict(raw): + class_name = raw["__class__"] + module_name = raw["__module__"] + type_matches = (class_name == expected_type.__name__ + and module_name == expected_type.__module__) + + if not type_matches: + msg = ( + f"df_loads: payload declares class " + f"{module_name}.{class_name} but expected " + f"{expected_type.__module__}.{expected_type.__name__}" + ) + if strict: + raise TypeError(msg) + logger.warning(msg) + + if strict: + # Bypass import_module entirely — call from_json directly. 
+ if not _has_json_protocol(expected_type): + raise TypeError( + f"df_loads: expected_type " + f"{expected_type.__module__}.{expected_type.__name__} " + f"does not expose from_json" + ) + return expected_type.from_json(raw["__data__"]) + + # Loose mode — legacy deserialization. + return json.loads(s, object_hook=_deserialize_custom_object) + + # Primitive / plain-JSON payload — validate the Python type. + if not _is_compatible(raw, expected_type): + msg = ( + f"df_loads: deserialized value ({type(raw).__name__}) is not " + f"compatible with expected type {expected_type}" + ) + if strict: + raise TypeError(msg) + logger.warning(msg) + + if strict: + return raw + # Loose mode — use legacy deserializer so nested custom objects + # (inside dicts/lists) are still reconstructed via object_hook. + return json.loads(s, object_hook=_deserialize_custom_object) + +def _has_json_protocol(cls: type) -> bool: + """Return True iff *cls* exposes callable `to_json` and `from_json`.""" + return callable(getattr(cls, "to_json", None)) and callable( + getattr(cls, "from_json", None) + ) + + +def _is_compatible(value: Any, expected_type: type) -> bool: + """Best-effort `isinstance` check that tolerates generic type hints.""" + try: + return isinstance(value, expected_type) + except TypeError: + # typing constructs like `List[int]` aren't valid for isinstance. + return True diff --git a/azure/durable_functions/models/utils/type_discovery.py b/azure/durable_functions/models/utils/type_discovery.py new file mode 100644 index 0000000..64da16c --- /dev/null +++ b/azure/durable_functions/models/utils/type_discovery.py @@ -0,0 +1,83 @@ +"""Best-effort type-hint discovery for Durable Functions call sites. + +These helpers feed the ``expected_type`` argument of +``df_serialization.df_loads`` so that custom-class instances can be +re-instantiated without consulting ``sys.modules`` / ``importlib``. 
+ +All public helpers swallow exceptions and return ``None`` on failure -- +the caller treats ``None`` as "no type information available", so unless +the call site passes its own ``expected_type``, ``df_loads`` runs without +one (legacy decoder in loose mode; custom objects rejected in strict). +""" + +from __future__ import annotations + +import inspect +import logging +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +def _unwrap_function_builder(name_or_callable: Any) -> Optional[Callable]: + """Return the underlying user function from a V2 ``FunctionBuilder``. + + Returns ``None`` for plain strings, plain callables, or anything we + don't recognize. + """ + # Avoid a hard dependency on the FunctionBuilder symbol (it lives in + # the azure-functions package and may move). + func = getattr(getattr(name_or_callable, "_function", None), "_func", None) + if callable(func): + return func + return None + + +def _return_annotation(fn: Callable) -> Optional[type]: + try: + sig = inspect.signature(fn) + except (TypeError, ValueError): + return None + ann = sig.return_annotation + if ann is inspect.Signature.empty: + return None + return ann if isinstance(ann, type) else None + + +def activity_output_type(name_or_callable: Any) -> Optional[type]: + """Discover the return-annotation type of a V2 activity function. + + Returns ``None`` if ``name_or_callable`` is a plain string (V1 model + or hand-written name) or if the annotation isn't a concrete type.
+ """ + fn = _unwrap_function_builder(name_or_callable) + if fn is None: + return None + return _return_annotation(fn) + + +def sub_orchestrator_output_type(name_or_callable: Any) -> Optional[type]: + """Discover the return-annotation type of a V2 sub-orchestrator function.""" + fn = _unwrap_function_builder(name_or_callable) + if fn is None: + return None + return _return_annotation(fn) + + +def entity_operation_input_type(entity_user_fn: Optional[Callable], + operation_name: str) -> Optional[type]: + """Best-effort discovery of an entity operation's input type. + + Entities in the V2 model are typically a single function that + dispatches on ``context.operation_name``. There is no general way to + statically associate an operation name with a parameter type; this + helper currently returns ``None`` for all such functions and exists + as the extension point for richer entity-dispatch patterns we may + add in the future (e.g. class-based entities with one method per + operation). + """ + if entity_user_fn is None or not operation_name: + return None + # Future work: inspect class-based entity dispatch tables. For now, + # signal "unknown" so decoding proceeds without an expected type. + return None diff --git a/azure/durable_functions/orchestrator.py b/azure/durable_functions/orchestrator.py index 9e3a29b..3717cf3 100644 --- a/azure/durable_functions/orchestrator.py +++ b/azure/durable_functions/orchestrator.py @@ -66,7 +66,12 @@ def handle(context: func.OrchestrationContext) -> str: context_body = getattr(context, "body", None) if context_body is None: context_body = context - return Orchestrator(fn).handle(DurableOrchestrationContext.from_json(context_body)) + ctx = DurableOrchestrationContext.from_json(context_body) + # Propagate the decorator-declared input type (set by + # @app.orchestration_trigger(input_type=...)) so that + # context.get_input() can decode the payload type-safely.
+ ctx._input_expected_type = getattr(handle, "_df_input_type", None) + return Orchestrator(fn).handle(ctx) handle.orchestrator_function = fn diff --git a/tests/models/test_Decorators.py b/tests/models/test_Decorators.py index cf6d114..0c753f7 100644 --- a/tests/models/test_Decorators.py +++ b/tests/models/test_Decorators.py @@ -34,6 +34,22 @@ def dummy_function(my_context): ] }) +def test_orchestration_trigger_input_type_stashed(app): + """Verify that input_type= on the decorator is stashed on the handle.""" + + class MyInput: + pass + + @app.orchestration_trigger(context_name="my_context", input_type=MyInput) + def dummy_function(my_context): + pass + + user_code = get_user_code(app) + assert user_code.get_function_name() == "dummy_function" + # The input type is stashed on the inner callable (the Orchestrator + # handle) which lives at Function._func. + assert getattr(user_code._func, "_df_input_type", None) is MyInput + def test_activity_trigger(app): @app.activity_trigger(input_name="my_input") diff --git a/tests/models/test_DurableOrchestrationContext.py b/tests/models/test_DurableOrchestrationContext.py index 3aecae5..690837d 100644 --- a/tests/models/test_DurableOrchestrationContext.py +++ b/tests/models/test_DurableOrchestrationContext.py @@ -101,6 +101,77 @@ def test_get_input_json_str(): assert 'Seattle' == result['city'] + +class _Order: + """Test fixture for expected_type round-trips.""" + def __init__(self, item: str, qty: int): + self.item = item + self.qty = qty + + @staticmethod + def to_json(obj): + return {"item": obj.item, "qty": obj.qty} + + @staticmethod + def from_json(data): + return _Order(data["item"], data["qty"]) + + +def test_get_input_with_expected_type_kwarg(): + from azure.durable_functions.models.utils.df_serialization import df_dumps + builder = ContextBuilder('test_function_context') + builder.input_ = df_dumps(_Order("widget", 5)) + context = DurableOrchestrationContext.from_json(builder.to_json_string()) + + result = 
context.get_input(expected_type=_Order) + assert isinstance(result, _Order) + assert result.item == "widget" + assert result.qty == 5 + + +def test_get_input_with_decorator_input_type(): + from azure.durable_functions.models.utils.df_serialization import df_dumps + builder = ContextBuilder('test_function_context') + builder.input_ = df_dumps(_Order("widget", 5)) + context = DurableOrchestrationContext.from_json(builder.to_json_string()) + # Simulate what Orchestrator.create does when input_type is set + context._input_expected_type = _Order + + result = context.get_input() + assert isinstance(result, _Order) + assert result.item == "widget" + + +def test_get_input_kwarg_overrides_decorator_type(): + """Call-site expected_type takes precedence over decorator input_type.""" + from azure.durable_functions.models.utils.df_serialization import df_dumps + + class _Alt: + def __init__(self, item, qty): + self.item = item + self.qty = qty + + @staticmethod + def to_json(obj): + return {"item": obj.item, "qty": obj.qty} + + @staticmethod + def from_json(data): + return _Alt(data["item"], data["qty"]) + + builder = ContextBuilder('test_function_context') + builder.input_ = df_dumps(_Order("widget", 5)) + context = DurableOrchestrationContext.from_json(builder.to_json_string()) + context._input_expected_type = _Order # decorator says _Order + + # expected_type is used for pre-validation only; the legacy decoder + # still uses the payload's declared class. A warning is emitted + # because _Alt != _Order. 
+ result = context.get_input(expected_type=_Alt) + assert isinstance(result, _Order) # legacy decoder uses payload class + assert result.item == "widget" + + def test_version_equals_version_from_execution_started_event(): builder = ContextBuilder('test_function_context') builder.history_events = [] diff --git a/tests/orchestrator/test_expected_type.py b/tests/orchestrator/test_expected_type.py new file mode 100644 index 0000000..58bcadb --- /dev/null +++ b/tests/orchestrator/test_expected_type.py @@ -0,0 +1,164 @@ +"""Tests for the expected_type kwarg on orchestration context APIs. + +Covers call_activity, call_sub_orchestrator, and their _with_retry variants +when an explicit expected_type is provided at the call site (V1 string-name +callers with no auto-discovery). +""" +import json +from datetime import datetime + +from tests.orchestrator.orchestrator_test_utils import ( + assert_orchestration_state_equals, + get_orchestration_state_result, +) +from tests.test_utils.ContextBuilder import ContextBuilder +from azure.durable_functions.models.OrchestratorState import OrchestratorState +from azure.durable_functions.models.actions.CallActivityAction import CallActivityAction +from azure.durable_functions.models.actions.CallSubOrchestratorAction import CallSubOrchestratorAction +from azure.durable_functions.models.RetryOptions import RetryOptions +from azure.durable_functions.models.utils.df_serialization import df_dumps + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class _Order: + def __init__(self, item: str, qty: int = 1): + self.item = item + self.qty = qty + + @staticmethod + def to_json(obj): + return {"item": obj.item, "qty": obj.qty} + + @staticmethod + def from_json(data): + return _Order(data["item"], data["qty"]) + + +def _base_state(output=None) -> OrchestratorState: + return OrchestratorState(is_done=False, actions=[], 
output=output) + + +def _add_activity_completed(ctx_builder, id_, result_str, name="DoWork"): + ctx_builder.add_task_scheduled_event(name=name, id_=id_) + ctx_builder.add_orchestrator_completed_event() + ctx_builder.add_orchestrator_started_event() + ctx_builder.add_task_completed_event(id_=id_, result=result_str) + + +def _add_sub_orch_completed(ctx_builder, id_, result_str, name="SubOrch"): + ctx_builder.add_sub_orchestrator_started_event(name=name, id_=id_, input_="") + ctx_builder.add_orchestrator_completed_event() + ctx_builder.add_orchestrator_started_event() + ctx_builder.add_sub_orchestrator_completed_event(result=result_str, id_=id_) + + +# --------------------------------------------------------------------------- +# call_activity with expected_type +# --------------------------------------------------------------------------- + +def orchestrator_activity_expected_type(context): + result = yield context.call_activity("DoWork", "x", expected_type=_Order) + return result.item + + +def test_call_activity_with_expected_type(): + payload = df_dumps(_Order("widget", 5)) + ctx = ContextBuilder("test") + _add_activity_completed(ctx, 0, payload) + + result = get_orchestration_state_result(ctx, orchestrator_activity_expected_type) + + assert result["isDone"] is True + # The orchestrator returns result.item which is "widget" + assert result["output"] == "widget" + + +# --------------------------------------------------------------------------- +# call_activity_with_retry with expected_type +# --------------------------------------------------------------------------- + +def orchestrator_activity_retry_expected_type(context): + opts = RetryOptions(5000, 3) + result = yield context.call_activity_with_retry( + "DoWork", opts, "x", expected_type=_Order) + return result.item + + +def test_call_activity_with_retry_expected_type(): + payload = df_dumps(_Order("gadget", 2)) + ctx = ContextBuilder("test") + _add_activity_completed(ctx, 0, payload) + + result = 
get_orchestration_state_result(ctx, orchestrator_activity_retry_expected_type) + + assert result["isDone"] is True + assert result["output"] == "gadget" + + +# --------------------------------------------------------------------------- +# call_sub_orchestrator with expected_type +# --------------------------------------------------------------------------- + +def orchestrator_sub_orch_expected_type(context): + result = yield context.call_sub_orchestrator( + "SubOrch", "input", expected_type=_Order) + return result.item + + +def test_call_sub_orchestrator_with_expected_type(): + payload = df_dumps(_Order("part", 10)) + ctx = ContextBuilder("test") + _add_sub_orch_completed(ctx, 0, payload) + + result = get_orchestration_state_result(ctx, orchestrator_sub_orch_expected_type) + + assert result["isDone"] is True + assert result["output"] == "part" + + +# --------------------------------------------------------------------------- +# call_sub_orchestrator_with_retry with expected_type +# --------------------------------------------------------------------------- + +def orchestrator_sub_orch_retry_expected_type(context): + opts = RetryOptions(5000, 3) + result = yield context.call_sub_orchestrator_with_retry( + "SubOrch", opts, "input", expected_type=_Order) + return result.item + + +def test_call_sub_orchestrator_with_retry_expected_type(): + payload = df_dumps(_Order("gizmo", 3)) + ctx = ContextBuilder("test") + _add_sub_orch_completed(ctx, 0, payload) + + result = get_orchestration_state_result(ctx, orchestrator_sub_orch_retry_expected_type) + + assert result["isDone"] is True + assert result["output"] == "gizmo" + + +# --------------------------------------------------------------------------- +# expected_type kwarg overrides auto-discovered type (None in V1) +# --------------------------------------------------------------------------- + +def orchestrator_override(context): + """Call with string name (V1) + expected_type; auto-discovery returns None.""" + result = 
yield context.call_activity("DoWork", "x", expected_type=_Order) + return [result.item, result.qty] + + +def test_expected_type_kwarg_used_when_auto_discovery_returns_none(): + payload = df_dumps(_Order("bolt", 99)) + ctx = ContextBuilder("test") + _add_activity_completed(ctx, 0, payload) + + result = get_orchestration_state_result(ctx, orchestrator_override) + + assert result["isDone"] is True + output = result["output"] + assert output[0] == "bolt" + assert output[1] == 99 diff --git a/tests/orchestrator/test_external_event.py b/tests/orchestrator/test_external_event.py index 263ef77..86df610 100644 --- a/tests/orchestrator/test_external_event.py +++ b/tests/orchestrator/test_external_event.py @@ -3,6 +3,7 @@ from tests.orchestrator.orchestrator_test_utils import assert_orchestration_state_equals, get_orchestration_state_result from tests.test_utils.ContextBuilder import ContextBuilder from azure.durable_functions.models.actions.WaitForExternalEventAction import WaitForExternalEventAction +from azure.durable_functions.models.utils.df_serialization import df_dumps def generator_function(context): result = yield context.wait_for_external_event("A") @@ -51,4 +52,37 @@ def test_succeeds_on_out_of_order_payload(): expected_state.actions.append([WaitForExternalEventAction("B")]) expected_state._is_done = True expected = expected_state.to_json() - assert_orchestration_state_equals(expected, result) \ No newline at end of file + assert_orchestration_state_equals(expected, result) + + +class _Payload: + """Simple custom class for testing expected_type on external events.""" + def __init__(self, value: str): + self.value = value + + @staticmethod + def to_json(obj): + return {"value": obj.value} + + @staticmethod + def from_json(data): + return _Payload(data["value"]) + + +def generator_function_with_expected_type(context): + result = yield context.wait_for_external_event("A", expected_type=_Payload) + return result.value + + +def test_external_event_with_expected_type(): 
+ """wait_for_external_event(expected_type=...) decodes custom objects.""" + timestamp = datetime.now() + json_input = df_dumps(_Payload("hello")) + context_builder = ContextBuilder() + context_builder.add_event_raised_event( + "A", input_=json_input, timestamp=timestamp, id_=-1) + result = get_orchestration_state_result( + context_builder, generator_function_with_expected_type) + + assert result["isDone"] is True + assert result["output"] == "hello" \ No newline at end of file diff --git a/tests/utils/test_df_serialization.py b/tests/utils/test_df_serialization.py new file mode 100644 index 0000000..6c1d692 --- /dev/null +++ b/tests/utils/test_df_serialization.py @@ -0,0 +1,656 @@ +"""Comprehensive round-trip and validation tests for df_serialization. + +Every data shape is tested in three configurations: + 1. No expected_type (legacy object_hook path) + 2. Loose mode + expected_type (warn on mismatch, legacy deserialize) + 3. Strict mode + expected_type (raise on mismatch, from_json directly) +""" + +import json +import logging +import os + +import pytest + +from azure.durable_functions.models.utils import df_serialization +from azure.durable_functions.models.utils.df_serialization import ( + df_dumps, + df_loads, + _get_serialize_default, + _STRICT_ENV_VAR, +) + + +# --------------------------------------------------------------------------- +# Helper classes +# --------------------------------------------------------------------------- + +class PlainPerson: + """Simple class: to_json returns a dict, from_json accepts a dict.""" + + def __init__(self, name: str, age: int): + self.name = name + self.age = age + + @staticmethod + def to_json(obj): + return {"name": obj.name, "age": obj.age} + + @staticmethod + def from_json(data): + return PlainPerson(data["name"], data["age"]) + + def __eq__(self, other): + return (isinstance(other, PlainPerson) + and self.name == other.name and self.age == other.age) + + +class ScalarPerson: + """to_json returns a scalar (str), 
not a dict.""" + + def __init__(self, name: str): + self.name = name + + @staticmethod + def to_json(obj): + return obj.name + + @staticmethod + def from_json(data): + return ScalarPerson(data) + + def __eq__(self, other): + return isinstance(other, ScalarPerson) and self.name == other.name + + +class Hat: + """Leaf object for nesting tests.""" + + def __init__(self, color: str): + self.color = color + + @staticmethod + def to_json(obj): + return {"color": obj.color} + + @staticmethod + def from_json(data): + return Hat(data["color"]) + + def __eq__(self, other): + return isinstance(other, Hat) and self.color == other.color + + +class NaiveOrder: + """Nested object whose from_json expects pre-constructed Hat instances. + + This relies on the bottom-up object_hook behavior — from_json receives + a Hat instance at data["hat"], not a raw dict. Works in loose mode but + fails in strict mode because strict skips object_hook. + """ + + def __init__(self, item: str, hat: Hat): + self.item = item + self.hat = hat + + @staticmethod + def to_json(obj): + return {"item": obj.item, "hat": obj.hat} + + @staticmethod + def from_json(data): + # Assumes data["hat"] is already a Hat instance (object_hook fired) + return NaiveOrder(data["item"], data["hat"]) + + def __eq__(self, other): + return (isinstance(other, NaiveOrder) + and self.item == other.item and self.hat == other.hat) + + +class SmartOrder: + """Nested object with strict-mode-compatible to_json / from_json. + + to_json produces plain JSON (calls Hat.to_json explicitly), so the + result is natively JSON-serializable without ``default=``. from_json + handles both the strict-mode shape (plain dict from to_json) and + the loose-mode shape (pre-constructed Hat or raw legacy dict). 
+ """ + + def __init__(self, item: str, hat: Hat): + self.item = item + self.hat = hat + + @staticmethod + def to_json(obj): + return {"item": obj.item, "hat": Hat.to_json(obj.hat)} + + @staticmethod + def from_json(data): + hat_data = data["hat"] + if isinstance(hat_data, Hat): + # Loose mode: object_hook already constructed the Hat + hat = hat_data + else: + # Strict mode or plain dict: reconstruct from to_json output + hat = Hat.from_json(hat_data) + return SmartOrder(data["item"], hat) + + def __eq__(self, other): + return (isinstance(other, SmartOrder) + and self.item == other.item and self.hat == other.hat) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def strict(monkeypatch): + """Enable strict typing mode for the duration of a test.""" + monkeypatch.setenv(_STRICT_ENV_VAR, "1") + + +@pytest.fixture +def loose(monkeypatch): + """Explicitly disable strict typing mode.""" + monkeypatch.delenv(_STRICT_ENV_VAR, raising=False) + + +# =================================================================== +# 1. 
PRIMITIVES (str, int, float, bool, None, list, dict) +# =================================================================== + +@pytest.mark.parametrize("value", [ + None, + True, + False, + 0, + -1, + 42, + 3.14, + "", + "hello", + [], + [1, 2, 3], + [True, None, "mixed"], + {}, + {"a": 1, "b": [1, 2]}, + {"nested": {"deep": {"value": 7}}}, +]) +class TestPrimitiveRoundTrips: + """Primitives must round-trip identically in all three paths.""" + + def test_no_expected_type(self, value): + assert df_loads(df_dumps(value)) == value + + def test_loose_with_matching_type(self, value, loose, caplog): + # Use the actual type of the value as expected_type + et = type(value) if value is not None else type(None) + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + result = df_loads(df_dumps(value), expected_type=et) + assert result == value + + def test_strict_with_matching_type(self, value, strict): + et = type(value) if value is not None else type(None) + result = df_loads(df_dumps(value), expected_type=et) + assert result == value + + +# =================================================================== +# 2. 
SIMPLE CUSTOM OBJECTS (dict-returning to_json) +# =================================================================== + +class TestSimpleObject: + + def test_no_expected_type(self): + obj = PlainPerson("andy", 99) + decoded = df_loads(df_dumps(obj)) + assert decoded == obj + + def test_loose_matching_type(self, loose): + obj = PlainPerson("andy", 99) + decoded = df_loads(df_dumps(obj), expected_type=PlainPerson) + assert decoded == obj + + def test_strict_matching_type(self, strict): + obj = PlainPerson("andy", 99) + decoded = df_loads(df_dumps(obj), expected_type=PlainPerson) + assert decoded == obj + + def test_loose_mismatched_type_warns(self, loose, caplog): + encoded = df_dumps(PlainPerson("a", 1)) + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + decoded = df_loads(encoded, expected_type=ScalarPerson) + # Loose mode: legacy decoder uses the payload's class + assert isinstance(decoded, PlainPerson) + assert any("payload declares class" in r.message for r in caplog.records) + + def test_strict_mismatched_type_raises(self, strict): + encoded = df_dumps(PlainPerson("a", 1)) + with pytest.raises(TypeError, match="payload declares class"): + df_loads(encoded, expected_type=ScalarPerson) + + +# =================================================================== +# 3. SCALAR-RETURNING to_json +# =================================================================== + +class TestScalarToJson: + + def test_no_expected_type(self): + obj = ScalarPerson("andy") + decoded = df_loads(df_dumps(obj)) + assert decoded == obj + + def test_loose_matching_type(self, loose): + obj = ScalarPerson("andy") + decoded = df_loads(df_dumps(obj), expected_type=ScalarPerson) + assert decoded == obj + + def test_strict_matching_type(self, strict): + obj = ScalarPerson("andy") + decoded = df_loads(df_dumps(obj), expected_type=ScalarPerson) + assert decoded == obj + + +# =================================================================== +# 4. 
DICT WITH OBJECT PROPERTIES e.g. {"person": PlainPerson(...)} +# =================================================================== + +class TestDictWithObjectProperty: + """A plain dict containing a custom object as a value.""" + + def _make_payload(self): + return {"person": PlainPerson("a", 1), "count": 7} + + def test_no_expected_type(self): + """Loose path: object_hook reconstructs nested objects.""" + decoded = df_loads(df_dumps(self._make_payload())) + assert decoded["count"] == 7 + assert isinstance(decoded["person"], PlainPerson) + assert decoded["person"].name == "a" + + def test_loose_expected_dict(self, loose, caplog): + """Loose path + expected_type=dict: works, inner objects reconstructed.""" + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + decoded = df_loads(df_dumps(self._make_payload()), expected_type=dict) + assert isinstance(decoded["person"], PlainPerson) + # No warning — top-level is a dict matching expected_type + assert not any("not compatible" in r.message for r in caplog.records) + + def test_strict_encode_fails_for_nested_custom_objects(self, strict): + """Strict mode: a plain dict containing a custom object cannot be + encoded — json.dumps runs without default= so PlainPerson raises TypeError.""" + with pytest.raises(TypeError): + df_dumps(self._make_payload()) + + +# =================================================================== +# 5. 
NESTED OBJECTS — "naive" from_json (expects pre-constructed) +# =================================================================== + +class TestNaiveNestedObject: + """NaiveOrder.from_json expects Hat to already be a Hat instance.""" + + def _make(self): + return NaiveOrder("widget", Hat("red")) + + def test_no_expected_type(self): + """Legacy path: object_hook fires bottom-up, Hat constructed first.""" + decoded = df_loads(df_dumps(self._make())) + assert isinstance(decoded, NaiveOrder) + assert isinstance(decoded.hat, Hat) + assert decoded.hat.color == "red" + + def test_loose_matching_type(self, loose): + """Loose + expected_type: legacy path still fires, nested works.""" + decoded = df_loads(df_dumps(self._make()), expected_type=NaiveOrder) + assert decoded == self._make() + + def test_strict_encode_fails_for_naive_to_json(self, strict): + """Strict mode: NaiveOrder.to_json embeds a Hat instance, which + is not natively JSON-serializable. df_dumps should fail at encode.""" + with pytest.raises(TypeError): + df_dumps(self._make()) + + +# =================================================================== +# 6. 
NESTED OBJECTS — "smart" from_json (handles raw dicts) +# =================================================================== + +class TestSmartNestedObject: + """SmartOrder.from_json manually calls Hat.from_json when needed.""" + + def _make(self): + return SmartOrder("gadget", Hat("blue")) + + def test_no_expected_type(self): + decoded = df_loads(df_dumps(self._make())) + assert isinstance(decoded, SmartOrder) + assert decoded.hat == Hat("blue") + + def test_loose_matching_type(self, loose): + decoded = df_loads(df_dumps(self._make()), expected_type=SmartOrder) + assert decoded == self._make() + + def test_strict_matching_type(self, strict): + """Strict mode works: SmartOrder.from_json handles the raw dict.""" + decoded = df_loads(df_dumps(self._make()), expected_type=SmartOrder) + assert decoded == self._make() + assert isinstance(decoded.hat, Hat) + assert decoded.hat.color == "blue" + + +# =================================================================== +# 7. LIST OF OBJECTS +# =================================================================== + +class TestListOfObjects: + + def _make(self): + return [PlainPerson("a", 1), PlainPerson("b", 2)] + + def test_no_expected_type(self): + decoded = df_loads(df_dumps(self._make())) + assert len(decoded) == 2 + assert all(isinstance(p, PlainPerson) for p in decoded) + + def test_loose_expected_list(self, loose): + decoded = df_loads(df_dumps(self._make()), expected_type=list) + assert len(decoded) == 2 + assert all(isinstance(p, PlainPerson) for p in decoded) + + def test_strict_encode_fails_for_nested_custom_objects(self, strict): + """Strict mode: a list of custom objects cannot be encoded — the + list itself doesn't have to_json, and json.dumps runs without + default= so PlainPerson raises TypeError.""" + with pytest.raises(TypeError): + df_dumps(self._make()) + + +# =================================================================== +# 8. 
PRIMITIVE TYPE MISMATCHES +# =================================================================== + +class TestPrimitiveTypeMismatch: + + def test_loose_warns(self, loose, caplog): + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + result = df_loads(df_dumps("hello"), expected_type=int) + assert result == "hello" + assert any("not compatible" in r.message for r in caplog.records) + + def test_strict_raises(self, strict): + with pytest.raises(TypeError, match="not compatible with expected type"): + df_loads(df_dumps("hello"), expected_type=int) + + def test_loose_str_expected_dict_warns(self, loose, caplog): + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + result = df_loads(df_dumps("hello"), expected_type=dict) + assert result == "hello" + assert any("not compatible" in r.message for r in caplog.records) + + def test_strict_str_expected_dict_raises(self, strict): + with pytest.raises(TypeError): + df_loads(df_dumps("hello"), expected_type=dict) + + +# =================================================================== +# 9. typing CONSTRUCTS (List[int], Optional[str], etc.) +# =================================================================== + +class TestTypingConstructs: + """Generic type hints can't be validated with isinstance — we pass + through without error in both modes.""" + + def test_loose_list_of_int(self, loose): + from typing import List + decoded = df_loads(df_dumps([1, 2, 3]), expected_type=List[int]) + assert decoded == [1, 2, 3] + + def test_strict_list_of_int(self, strict): + from typing import List + decoded = df_loads(df_dumps([1, 2, 3]), expected_type=List[int]) + assert decoded == [1, 2, 3] + + def test_loose_optional_str(self, loose): + from typing import Optional + decoded = df_loads(df_dumps("hi"), expected_type=Optional[str]) + assert decoded == "hi" + + +# =================================================================== +# 10. 
STRICT MODE ENV VAR VALUES +# =================================================================== + +class TestStrictModeEnvVar: + + @pytest.mark.parametrize("val", ["1", "true", "yes", "TRUE", "Yes", " 1 "]) + def test_truthy_values_enable_strict(self, monkeypatch, val): + monkeypatch.setenv(_STRICT_ENV_VAR, val) + with pytest.raises(TypeError): + df_loads(df_dumps("hello"), expected_type=int) + + @pytest.mark.parametrize("val", ["0", "false", "no", "", "nope"]) + def test_non_truthy_values_stay_loose(self, monkeypatch, val, caplog): + monkeypatch.setenv(_STRICT_ENV_VAR, val) + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + result = df_loads(df_dumps("hello"), expected_type=int) + assert result == "hello" + + def test_unset_is_loose(self, monkeypatch): + monkeypatch.delenv(_STRICT_ENV_VAR, raising=False) + result = df_loads(df_dumps("hello"), expected_type=int) + assert result == "hello" + + +# =================================================================== +# 10b. 
STRICT MODE WITHOUT expected_type +# =================================================================== + +class TestStrictNoExpectedType: + """In strict mode, df_loads without expected_type must never call import_module.""" + + def test_primitive_returns_raw(self, strict): + assert df_loads(df_dumps(42)) == 42 + + def test_string_returns_raw(self, strict): + assert df_loads(df_dumps("hello")) == "hello" + + def test_none_returns_raw(self, strict): + assert df_loads(df_dumps(None)) is None + + def test_plain_dict_returns_raw(self, strict): + d = {"key": "value", "n": 1} + assert df_loads(df_dumps(d)) == d + + def test_plain_list_returns_raw(self, strict): + lst = [1, "two", None] + assert df_loads(df_dumps(lst)) == lst + + def test_custom_object_raises(self, strict): + s = df_dumps(PlainPerson("alice", 30)) + with pytest.raises(TypeError, match="strict mode requires expected_type"): + df_loads(s) + + def test_custom_object_error_includes_class(self, strict): + s = df_dumps(PlainPerson("alice", 30)) + with pytest.raises(TypeError, match="PlainPerson"): + df_loads(s) + + def test_loose_mode_custom_object_still_works(self, loose): + """Without strict, the legacy path runs even without expected_type.""" + p = PlainPerson("bob", 25) + result = df_loads(df_dumps(p)) + assert isinstance(result, PlainPerson) + assert result.name == "bob" + + +# =================================================================== +# 11. 
WIRE FORMAT VERIFICATION +# =================================================================== + +class TestWireFormat: + + def test_df_dumps_matches_legacy_json_dumps(self): + from azure.functions._durable_functions import _serialize_custom_object + value = {"key": "value", "list": [1, 2, 3]} + assert df_dumps(value) == json.dumps(value, default=_serialize_custom_object) + + def test_custom_object_produces_legacy_keys(self): + raw = json.loads(df_dumps(PlainPerson("andy", 99))) + assert raw == { + "__class__": "PlainPerson", + "__module__": __name__, + "__data__": {"name": "andy", "age": 99}, + } + + def test_scalar_to_json_produces_legacy_keys(self): + raw = json.loads(df_dumps(ScalarPerson("andy"))) + assert raw == { + "__class__": "ScalarPerson", + "__module__": __name__, + "__data__": "andy", + } + + def test_nested_object_produces_plain_json_data(self): + """SmartOrder.to_json serializes Hat explicitly, so __data__ + contains plain JSON — no nested legacy envelope.""" + raw = json.loads(df_dumps(SmartOrder("gadget", Hat("blue")))) + assert raw["__class__"] == "SmartOrder" + assert raw["__data__"] == {"item": "gadget", "hat": {"color": "blue"}} + + +# =================================================================== +# 12. _get_serialize_default +# =================================================================== + +class TestGetSerializeDefault: + + def test_returns_callable(self): + cb = _get_serialize_default() + assert callable(cb) + + def test_produces_legacy_dict(self): + cb = _get_serialize_default() + result = cb(PlainPerson("a", 1)) + assert result == { + "__class__": "PlainPerson", + "__module__": __name__, + "__data__": {"name": "a", "age": 1}, + } + + def test_strict_returns_none(self, strict): + cb = _get_serialize_default() + assert cb is None + + +# =================================================================== +# 13. 
ENCODE ERRORS +# =================================================================== + +class TestEncodeErrors: + + def test_class_without_to_json(self): + class NoProtocol: + pass + with pytest.raises(TypeError): + df_dumps(NoProtocol()) + + def test_set(self): + with pytest.raises(TypeError): + df_dumps({1, 2, 3}) + + def test_bytes(self): + with pytest.raises(TypeError): + df_dumps(b"hello") + + +# =================================================================== +# 13b. STRICT-MODE ENCODE +# =================================================================== + +class TestStrictEncode: + """In strict mode, df_dumps rejects non-serializable nested values.""" + + def test_primitive(self, strict): + assert df_dumps(42) == "42" + + def test_string(self, strict): + assert df_dumps("hello") == '"hello"' + + def test_plain_dict(self, strict): + assert json.loads(df_dumps({"a": 1})) == {"a": 1} + + def test_custom_object_top_level_ok(self, strict): + """Top-level custom object is wrapped in envelope.""" + raw = json.loads(df_dumps(PlainPerson("andy", 99))) + assert raw["__class__"] == "PlainPerson" + assert raw["__data__"] == {"name": "andy", "age": 99} + + def test_strict_smart_order_data_is_plain_json(self, strict): + """SmartOrder.to_json returns plain JSON, so encoding succeeds + and __data__ contains no nested envelopes.""" + raw = json.loads(df_dumps(SmartOrder("gadget", Hat("blue")))) + assert raw["__class__"] == "SmartOrder" + assert raw["__data__"] == {"item": "gadget", "hat": {"color": "blue"}} + + def test_strict_naive_order_fails(self, strict): + """NaiveOrder.to_json returns a Hat instance — not serializable.""" + with pytest.raises(TypeError): + df_dumps(NaiveOrder("widget", Hat("red"))) + + def test_strict_dict_with_custom_value_fails(self, strict): + """Plain dict containing a custom object — not serializable.""" + with pytest.raises(TypeError): + df_dumps({"person": PlainPerson("a", 1)}) + + def test_strict_list_with_custom_value_fails(self, strict): 
+ """List containing custom objects — not serializable.""" + with pytest.raises(TypeError): + df_dumps([PlainPerson("a", 1)]) + + def test_loose_dict_with_custom_value_ok(self, loose): + """In loose mode, nested custom objects are still auto-wrapped.""" + raw = json.loads(df_dumps({"person": PlainPerson("a", 1)})) + assert raw["person"]["__class__"] == "PlainPerson" + + +# =================================================================== +# 14. EDGE CASES +# =================================================================== + +class TestEdgeCases: + + def test_bool_does_not_become_int(self): + """bool is a subclass of int — verify it stays bool.""" + out = df_loads(df_dumps(True)) + assert out is True + assert isinstance(out, bool) + + def test_none_with_expected_type_nonetype(self, loose): + assert df_loads(df_dumps(None), expected_type=type(None)) is None + + def test_none_with_expected_type_nonetype_strict(self, strict): + assert df_loads(df_dumps(None), expected_type=type(None)) is None + + def test_empty_dict_expected_dict(self, loose): + assert df_loads(df_dumps({}), expected_type=dict) == {} + + def test_empty_list_expected_list(self, strict): + assert df_loads(df_dumps([]), expected_type=list) == [] + + def test_tuple_becomes_list(self): + """Tuples serialize as JSON arrays — come back as lists.""" + assert df_loads(df_dumps((1, 2, 3))) == [1, 2, 3] + + def test_int_dict_keys_become_strings(self): + decoded = df_loads(df_dumps({1: "one", 2: "two"})) + assert decoded == {"1": "one", "2": "two"} + + def test_no_expected_type_no_warning(self, caplog): + """When expected_type is None, no warnings should fire.""" + with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): + df_loads(df_dumps(PlainPerson("a", 1))) + assert not any("not compatible" in r.message for r in caplog.records) + assert not any("payload declares" in r.message for r in caplog.records) diff --git a/tests/utils/test_type_discovery.py b/tests/utils/test_type_discovery.py new 
file mode 100644 index 0000000..a5b8390 --- /dev/null +++ b/tests/utils/test_type_discovery.py @@ -0,0 +1,81 @@ +"""Tests for type_discovery helpers.""" + +from typing import Optional +from unittest.mock import MagicMock + +from azure.durable_functions.models.utils.type_discovery import ( + activity_output_type, + sub_orchestrator_output_type, + entity_operation_input_type, +) + + +class _Result: + pass + + +def _make_function_builder(fn): + """Build a minimal stand-in for FunctionBuilder._function._func.""" + fb = MagicMock() + fb._function._func = fn + return fb + + +# --------------------------------------------------------------------------- +# activity_output_type +# --------------------------------------------------------------------------- + +def test_activity_output_type_returns_annotation(): + def my_activity(x) -> _Result: + return _Result() + fb = _make_function_builder(my_activity) + assert activity_output_type(fb) is _Result + + +def test_activity_output_type_returns_none_for_string(): + assert activity_output_type("activity_name") is None + + +def test_activity_output_type_returns_none_when_unannotated(): + def my_activity(x): + return None + fb = _make_function_builder(my_activity) + assert activity_output_type(fb) is None + + +def test_activity_output_type_returns_none_for_typing_construct(): + def my_activity(x) -> Optional[_Result]: + return None + fb = _make_function_builder(my_activity) + # Optional[_Result] is not a concrete class, so we return None. 
+ assert activity_output_type(fb) is None + + +# --------------------------------------------------------------------------- +# sub_orchestrator_output_type (same shape as activity) +# --------------------------------------------------------------------------- + +def test_sub_orchestrator_output_type_returns_annotation(): + def my_sub_orch(ctx) -> _Result: + return _Result() + fb = _make_function_builder(my_sub_orch) + assert sub_orchestrator_output_type(fb) is _Result + + +def test_sub_orchestrator_output_type_returns_none_for_string(): + assert sub_orchestrator_output_type("orch_name") is None + + +# --------------------------------------------------------------------------- +# entity_operation_input_type (always None today) +# --------------------------------------------------------------------------- + +def test_entity_operation_input_type_returns_none(): + def my_entity(ctx): + pass + assert entity_operation_input_type(my_entity, "add") is None + + +def test_entity_operation_input_type_returns_none_for_missing_inputs(): + assert entity_operation_input_type(None, "add") is None + assert entity_operation_input_type(lambda ctx: None, "") is None From ab79e05657389ef1c5ebcc20838c5e5463f8d2b2 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Thu, 4 Jun 2026 10:49:46 -0600 Subject: [PATCH 2/4] Use Functions SDK shim, warn, test in pipelines --- .github/workflows/validate.yml | 21 +- CHANGELOG.md | 5 +- azure-pipelines.yml | 28 + .../models/utils/df_serialization.py | 276 ++------ eng/templates/build.yml | 27 +- tests/utils/test_df_serialization.py | 639 +++--------------- 6 files changed, 249 insertions(+), 747 deletions(-) diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index c1e81e7..9878cca 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -13,6 +13,22 @@ on: jobs: validate: runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + # Fallback path: on Python 3.9 the SDK's df_dumps / 
df_loads cannot + # be installed (azure-functions 2.x requires >=3.13 and the 1.26.0 + # line requires >=3.10), so this leg exercises the legacy + # serialization fallback in df_serialization. + - python-version: "3.9" + functions-sdk: "" + # SDK path: Python 3.13 with the beta that first ships df_dumps / + # df_loads, exercising the SDK-delegated serialization branch. + # TODO: change to "azure-functions>=2.2.0" once 2.2.0 GA ships, and + # drop the explicit override step below. + - python-version: "3.13" + functions-sdk: "azure-functions>=2.2.0b5" steps: - name: Checkout repository uses: actions/checkout@v2 @@ -20,11 +36,14 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.9 + python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip pip install -r requirements.txt + - name: Install Functions SDK override + if: matrix.functions-sdk != '' + run: pip install "${{ matrix.functions-sdk }}" - name: Run Linter run: | cd azure diff --git a/CHANGELOG.md b/CHANGELOG.md index d9af378..0751fd6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 +7,8 @@ All notable changes to this project will be documented in this file. ### Added - Client operation correlation logging: `FunctionInvocationId` is now propagated via HTTP headers to the host for client operations, enabling correlation with host logs. -- Centralized JSON serialization module (`azure.durable_functions.models.utils.df_serialization`): all serialization/deserialization of user payloads (orchestrator inputs/outputs, activity arguments and results, sub-orchestrator payloads, entity inputs/outputs, and client inputs) now flows through `df_dumps` / `df_loads`, replacing scattered `json.dumps(…, default=_serialize_custom_object)` / `json.loads(…, object_hook=_deserialize_custom_object)` calls. 
The wire format is **unchanged** — builtins serialize to plain JSON and custom objects continue to use the `{"__class__", "__module__", "__data__"}` convention. -- Type-hint-driven validation via `df_loads(s, expected_type=...)`: when the V2 programming model provides a return-type annotation for an activity or sub-orchestrator, `df_loads` validates the deserialized payload against that type **before** the legacy `object_hook` fires, catching class/module mismatches early. -- **Strict typing mode** (opt-in via `AZURE_FUNCTIONS_DURABLE_STRICT_TYPING=1`): when enabled, `import_module` is never called on either encode or decode. On encode, `df_dumps` wraps only the top-level custom object — `to_json()` must return plain-JSON-serializable data (nested custom objects must be serialized explicitly). On decode, `df_loads` calls `expected_type.from_json(raw["__data__"])` directly; `df_loads` without `expected_type` raises `TypeError` for custom-object payloads. A `TypeError` is also raised on type mismatch. +- Centralized JSON serialization module (`azure.durable_functions.models.utils.df_serialization`): all serialization/deserialization of user payloads (orchestrator inputs/outputs, activity arguments and results, sub-orchestrator payloads, entity inputs/outputs, and client inputs) now flows through `df_dumps` / `df_loads`, replacing scattered `json.dumps(…, default=_serialize_custom_object)` / `json.loads(…, object_hook=_deserialize_custom_object)` calls. This module is a thin shim over the Azure Functions SDK: when the installed `azure-functions` exposes `df_dumps` / `df_loads` (the centralized serializers with type-validation and strict-typing support), they are used directly so our serialization matches the SDK's `ActivityTriggerConverter` at the host boundary; otherwise it falls back to the legacy `_serialize_custom_object` / `_deserialize_custom_object` hooks, which keeps both sides symmetric. 
The wire format is **unchanged** — builtins serialize to plain JSON and custom objects continue to use the `{"__class__", "__module__", "__data__"}` convention. +- Type-hint-driven validation via `df_loads(s, expected_type=...)`: when the V2 programming model provides a return-type annotation for an activity or sub-orchestrator, the annotation is threaded through call sites so the SDK's `df_loads` can validate the deserialized payload against that type (when available). On older `azure-functions` releases the argument is accepted but ignored. - Return-type discovery for V2 decorated activities/sub-orchestrators (`azure.durable_functions.models.utils.type_discovery`): resolves the concrete return annotation from the user's registered function, used to supply `expected_type` to `df_loads`. ## 1.0.0b6 diff --git a/azure-pipelines.yml b/azure-pipelines.yml index bc20eef..e1cfc4e 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -65,5 +65,33 @@ stages: inputs: PathtoPublish: dist ArtifactName: $(componentArtifactName) + + - job: Test_Functions_Sdk_Path + displayName: Test SDK Serialization Path (Py 3.13) + # The Build_Durable_Functions job runs on Python 3.9, where the SDK's + # df_dumps / df_loads cannot be installed (azure-functions 2.x requires + # >=3.13), so it only exercises the legacy serialization fallback. This + # job runs on Python 3.13 with the beta that first ships df_dumps / + # df_loads to cover the SDK-delegated branch in df_serialization. + # TODO: change the override to 'azure-functions>=2.2.0' once 2.2.0 GA + # ships, and drop the explicit install step. 
+ pool: + name: "1ES-Hosted-AzFunc" + demands: + - ImageOverride -equals MMSUbuntu20.04TLS + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.13' + - script: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install "azure-functions>=2.2.0b5" + workingDirectory: $(baseFolder) + displayName: 'Install dependencies (SDK serializers)' + - script: | + pip install pytest pytest-azurepipelines + pytest --ignore=samples-v2 + displayName: 'pytest' diff --git a/azure/durable_functions/models/utils/df_serialization.py b/azure/durable_functions/models/utils/df_serialization.py index 31bae9b..544b34c 100644 --- a/azure/durable_functions/models/utils/df_serialization.py +++ b/azure/durable_functions/models/utils/df_serialization.py @@ -1,34 +1,38 @@ """Centralized JSON serialization for Durable Functions payloads. -This module wraps the legacy `json.dumps(value, default=_serialize_custom_object)` -/ `json.loads(s, object_hook=_deserialize_custom_object)` pipeline from -`azure.functions._durable_functions` behind `df_dumps` and `df_loads`. - -The wire format is **unchanged** -- builtins serialize to plain JSON and custom -objects use the `{"__class__": ..., "__module__": ..., "__data__": ...}` -convention that the Durable extension and downstream consumers already expect. - -`df_loads` adds an optional `expected_type` parameter that controls -type validation. Behavior depends on the typing mode: - -* **Loose mode** (default) -- the payload is inspected before - deserialization and a warning is logged on type mismatch, then the - legacy ``object_hook`` pipeline runs as usual. -* **Strict mode** -- ``import_module`` is never called on either side. - On encode, ``to_json`` is called on the top-level object only and - the result must be plain-JSON-serializable (nested custom objects - are **not** auto-encoded -- ``to_json`` must handle them). On - decode, ``expected_type.from_json`` is invoked directly with plain - JSON data. 
A ``TypeError`` is raised on type mismatch or if - ``expected_type`` is not provided for a custom-object payload. - Opt in by setting ``AZURE_FUNCTIONS_DURABLE_STRICT_TYPING`` to a - truthy value (``1``, ``true``, ``yes``).""" +This module is a thin shim over the Azure Functions SDK serialization +helpers in ``azure.functions._durable_functions``. + +When the installed ``azure-functions`` package exposes ``df_dumps`` / +``df_loads`` (the centralized serializers with optional type validation +and strict-typing support), this module re-exports them directly so that +our serialization matches **exactly** what the SDK's +``ActivityTriggerConverter`` uses at the host boundary. + +When those symbols are **not** available (older ``azure-functions`` +releases), we fall back to the legacy plain pipeline -- +``json.dumps(value, default=_serialize_custom_object)`` / +``json.loads(s, object_hook=_deserialize_custom_object)`` -- which is the +same behavior the SDK converter uses in those versions. + +We deliberately do **not** substitute a richer local implementation on the +fallback path: if ``df_dumps`` / ``df_loads`` are not available from the +SDK, the SDK's ``ActivityTriggerConverter`` will not use them either, so +emulating the enhanced behavior locally would make our serialization +diverge from the converter that actually encodes and decodes activity +payloads. Using only the ``_serialize_custom_object`` / +``_deserialize_custom_object`` hooks -- which exist in every supported +``azure-functions`` release -- keeps both sides symmetric. + +The wire format is **unchanged** -- builtins serialize to plain JSON and +custom objects use the ``{"__class__", "__module__", "__data__"}`` +convention that the Durable extension and downstream consumers expect. 
+""" from __future__ import annotations import json -import logging -import os +import warnings from typing import Any, Optional from azure.functions._durable_functions import ( @@ -36,191 +40,53 @@ _serialize_custom_object, ) -logger = logging.getLogger(__name__) - -_STRICT_ENV_VAR = "AZURE_FUNCTIONS_DURABLE_STRICT_TYPING" -_TRUTHY = frozenset({"1", "true", "yes"}) - - -def _is_strict_mode() -> bool: - return os.environ.get(_STRICT_ENV_VAR, "").strip().lower() in _TRUTHY - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- +try: + # Preferred: the SDK's centralized serializers (type-validation and + # strict-typing aware). Available in azure-functions >= 2.2.0 (Python + # >= 3.13) and >= 1.26.0 (Python 3.10-3.12). + from azure.functions._durable_functions import ( # type: ignore + df_dumps, + df_loads, + ) +except ImportError: + warnings.warn( + "The installed 'azure-functions' package does not provide the " + "centralized 'df_dumps' / 'df_loads' serializers. Durable Functions " + "is falling back to the legacy serialization pipeline; the wire " + "format is unchanged, but payload type validation (the 'expected_type' " + "argument and strict typing mode) is unavailable. Upgrade to " + "azure-functions>=2.2.0 on Python>=3.13, or azure-functions>=1.26.0 " + "on Python 3.10-3.12, to enable type-validated serialization.", + stacklevel=2, + ) -def df_dumps(value: Any) -> str: - """Serialize *value* to a JSON string. + def df_dumps(value: Any) -> str: + """Serialize *value* to JSON via the legacy custom-object hook.""" + return json.dumps(value, default=_serialize_custom_object) - In **loose mode** (default), custom objects are encoded recursively - via the legacy ``default=_serialize_custom_object`` handler — any - nested custom object is automatically wrapped in the - ``{"__class__", "__module__", "__data__"}`` envelope. 
+ def df_loads(s: str, expected_type: Optional[type] = None) -> Any: + """Deserialize *s* via the legacy custom-object hook. - In **strict mode**, the top-level custom object (if it has - ``to_json``) is wrapped in the legacy envelope, but the - ``__data__`` payload is serialized as **plain JSON** — no - ``default=`` hook fires. This means ``to_json()`` must return a - value that is natively JSON-serializable (dicts, lists, strings, - numbers, bools, None). A ``TypeError`` is raised at encode time - if any nested value is not serializable. - """ - if _is_strict_mode(): - if hasattr(value, "to_json"): - envelope = _serialize_custom_object(value) - return json.dumps(envelope) - # Primitive / plain-JSON value — serialize without default= - # so stray custom objects are caught immediately. - return json.dumps(value) - return json.dumps(value, default=_serialize_custom_object) - - -def df_loads(s: str, expected_type: Optional[type] = None) -> Any: - """Deserialize a JSON string, optionally validating the result type. - - Parameters - ---------- - s : str - The JSON-encoded payload. - expected_type : type, optional - When provided the raw JSON is parsed first (without triggering - ``import_module`` via the legacy ``object_hook``). If the - payload is a legacy custom-object dict its embedded class info - is validated against *expected_type* **before** any module is - imported. A matching *expected_type* is used to call - ``from_json`` directly, avoiding ``import_module`` entirely. - In loose mode a warning is emitted on mismatch; in strict mode - a ``TypeError`` is raised. - """ - if expected_type is not None: - return _loads_with_expected_type(s, expected_type) - - if _is_strict_mode(): - return _loads_strict_no_type(s) - - return json.loads(s, object_hook=_deserialize_custom_object) - - -def _loads_strict_no_type(s: str) -> Any: - """Strict-mode fallback when no *expected_type* is available. - - Parses without ``object_hook`` so ``import_module`` is never called. 
- If the top-level value is a legacy custom-object dict, raises - ``TypeError`` — the caller must supply an ``expected_type`` to - deserialize custom objects in strict mode. - """ - raw = json.loads(s) - if _is_legacy_custom_dict(raw): - raise TypeError( - "df_loads: strict mode requires expected_type to " - "deserialize custom-object payloads, but none was provided. " - f"Payload declares {raw['__module__']}.{raw['__class__']}." - ) - return raw - - -def _get_serialize_default(): - """Return the `default` callback for `json.dumps`. - - Use this in places that build their own `json.dumps` call (e.g. - `OrchestratorState.to_json_string`) rather than going through - `df_dumps`. - - In strict mode returns ``None`` — `OrchestratorState` fields are - already serialized via `df_dumps` so there should be no remaining - custom objects to encode. A stray custom object will raise - ``TypeError`` from ``json.dumps``, surfacing the problem early. - """ - if _is_strict_mode(): - return None - return _serialize_custom_object - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -_LEGACY_KEYS = frozenset({"__class__", "__module__", "__data__"}) - - -def _is_legacy_custom_dict(d: Any) -> bool: - """Return True if *d* is a dict with legacy custom-object markers.""" - return isinstance(d, dict) and _LEGACY_KEYS.issubset(d) - - -def _loads_with_expected_type(s: str, expected_type: type) -> Any: - """Parse *s* and validate against *expected_type*. - - The raw JSON is parsed **without** the legacy ``object_hook`` so we - can inspect the payload before ``import_module`` fires. - - * **Strict mode** -- for custom-object payloads, calls - ``expected_type.from_json`` directly (no ``import_module``). For - primitives, validates then returns the plain value. Raises - ``TypeError`` on mismatch. 
- * **Loose mode** -- logs a warning on mismatch, then falls through - to the normal ``json.loads(s, object_hook=...)`` legacy path. - """ - raw = json.loads(s) - strict = _is_strict_mode() - - if _is_legacy_custom_dict(raw): - class_name = raw["__class__"] - module_name = raw["__module__"] - type_matches = (class_name == expected_type.__name__ - and module_name == expected_type.__module__) - - if not type_matches: - msg = ( - f"df_loads: payload declares class " - f"{module_name}.{class_name} but expected " - f"{expected_type.__module__}.{expected_type.__name__}" - ) - if strict: - raise TypeError(msg) - logger.warning(msg) - - if strict: - # Bypass import_module entirely — call from_json directly. - if not _has_json_protocol(expected_type): - raise TypeError( - f"df_loads: expected_type " - f"{expected_type.__module__}.{expected_type.__name__} " - f"does not expose from_json" - ) - return expected_type.from_json(raw["__data__"]) - - # Loose mode — legacy deserialization. + ``expected_type`` is accepted for call-site compatibility but is + ignored on this fallback path; type validation is only performed + by the SDK's ``df_loads`` when it is available. + """ return json.loads(s, object_hook=_deserialize_custom_object) - # Primitive / plain-JSON payload — validate the Python type. - if not _is_compatible(raw, expected_type): - msg = ( - f"df_loads: deserialized value ({type(raw).__name__}) is not " - f"compatible with expected type {expected_type}" - ) - if strict: - raise TypeError(msg) - logger.warning(msg) - - if strict: - return raw - # Loose mode — use legacy deserializer so nested custom objects - # (inside dicts/lists) are still reconstructed via object_hook. 
- return json.loads(s, object_hook=_deserialize_custom_object) - -def _has_json_protocol(cls: type) -> bool: - """Return True iff *cls* exposes callable `to_json` and `from_json`.""" - return callable(getattr(cls, "to_json", None)) and callable( - getattr(cls, "from_json", None) + +try: + from azure.functions._durable_functions import ( # type: ignore + _get_serialize_default, ) +except ImportError: + def _get_serialize_default(): + """Return the ``default`` callback for a standalone ``json.dumps``. + + Used where code builds its own ``json.dumps`` call (e.g. + ``OrchestratorState.to_json_string``) rather than going through + ``df_dumps``. + """ + return _serialize_custom_object -def _is_compatible(value: Any, expected_type: type) -> bool: - """Best-effort `isinstance` check that tolerates generic type hints.""" - try: - return isinstance(value, expected_type) - except TypeError: - # typing constructs like `List[int]` aren't valid for isinstance. - return True +__all__ = ["df_dumps", "df_loads", "_get_serialize_default"] diff --git a/eng/templates/build.yml b/eng/templates/build.yml index 35ad078..fe2064e 100644 --- a/eng/templates/build.yml +++ b/eng/templates/build.yml @@ -39,4 +39,29 @@ jobs: inputs: SourceFolder: dist Contents: '**' - TargetFolder: $(Build.ArtifactStagingDirectory) \ No newline at end of file + TargetFolder: $(Build.ArtifactStagingDirectory) + + - job: Test_Functions_Sdk_Path + displayName: Test SDK Serialization Path (Py 3.13) + + # The Build job runs on Python 3.9, where the SDK's df_dumps / df_loads + # cannot be installed (azure-functions 2.x requires >=3.13), so it only + # exercises the legacy serialization fallback. This job runs on Python + # 3.13 with the beta that first ships df_dumps / df_loads to cover the + # SDK-delegated branch in df_serialization. + # TODO: change the override to 'azure-functions>=2.2.0' once 2.2.0 GA + # ships, and drop the explicit install step. 
+ steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.13' + - script: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install "azure-functions>=2.2.0b5" + workingDirectory: $(System.DefaultWorkingDirectory) + displayName: 'Install dependencies (SDK serializers)' + - script: | + pip install pytest pytest-azurepipelines + pytest --ignore=samples-v2 + displayName: 'pytest' \ No newline at end of file diff --git a/tests/utils/test_df_serialization.py b/tests/utils/test_df_serialization.py index 6c1d692..e45af8c 100644 --- a/tests/utils/test_df_serialization.py +++ b/tests/utils/test_df_serialization.py @@ -1,23 +1,26 @@ -"""Comprehensive round-trip and validation tests for df_serialization. +"""Tests for the df_serialization shim. -Every data shape is tested in three configurations: - 1. No expected_type (legacy object_hook path) - 2. Loose mode + expected_type (warn on mismatch, legacy deserialize) - 3. Strict mode + expected_type (raise on mismatch, from_json directly) +``df_serialization`` is a thin shim over the Azure Functions SDK +serializers in ``azure.functions._durable_functions``: + +* When the installed ``azure-functions`` exposes ``df_dumps`` / + ``df_loads``, this module re-exports them directly. +* Otherwise it falls back to the legacy plain pipeline + (``json.dumps(value, default=_serialize_custom_object)`` / + ``json.loads(s, object_hook=_deserialize_custom_object)``). + +The richer type-validation / strict-typing behavior lives in (and is +tested by) the SDK; these tests only assert the contract this shim is +responsible for: round-tripping payloads and preserving the wire format. 
""" import json -import logging -import os -import pytest - -from azure.durable_functions.models.utils import df_serialization +import azure.functions._durable_functions as _sdk from azure.durable_functions.models.utils.df_serialization import ( df_dumps, df_loads, _get_serialize_default, - _STRICT_ENV_VAR, ) @@ -45,24 +48,6 @@ def __eq__(self, other): and self.name == other.name and self.age == other.age) -class ScalarPerson: - """to_json returns a scalar (str), not a dict.""" - - def __init__(self, name: str): - self.name = name - - @staticmethod - def to_json(obj): - return obj.name - - @staticmethod - def from_json(data): - return ScalarPerson(data) - - def __eq__(self, other): - return isinstance(other, ScalarPerson) and self.name == other.name - - class Hat: """Leaf object for nesting tests.""" @@ -81,13 +66,8 @@ def __eq__(self, other): return isinstance(other, Hat) and self.color == other.color -class NaiveOrder: - """Nested object whose from_json expects pre-constructed Hat instances. - - This relies on the bottom-up object_hook behavior — from_json receives - a Hat instance at data["hat"], not a raw dict. Works in loose mode but - fails in strict mode because strict skips object_hook. - """ +class NestedOrder: + """Nested object relying on bottom-up object_hook reconstruction.""" def __init__(self, item: str, hat: Hat): self.item = item @@ -99,66 +79,19 @@ def to_json(obj): @staticmethod def from_json(data): - # Assumes data["hat"] is already a Hat instance (object_hook fired) - return NaiveOrder(data["item"], data["hat"]) + return NestedOrder(data["item"], data["hat"]) def __eq__(self, other): - return (isinstance(other, NaiveOrder) + return (isinstance(other, NestedOrder) and self.item == other.item and self.hat == other.hat) -class SmartOrder: - """Nested object with strict-mode-compatible to_json / from_json. - - to_json produces plain JSON (calls Hat.to_json explicitly), so the - result is natively JSON-serializable without ``default=``. 
from_json - handles both the strict-mode shape (plain dict from to_json) and - the loose-mode shape (pre-constructed Hat or raw legacy dict). - """ - - def __init__(self, item: str, hat: Hat): - self.item = item - self.hat = hat - - @staticmethod - def to_json(obj): - return {"item": obj.item, "hat": Hat.to_json(obj.hat)} - - @staticmethod - def from_json(data): - hat_data = data["hat"] - if isinstance(hat_data, Hat): - # Loose mode: object_hook already constructed the Hat - hat = hat_data - else: - # Strict mode or plain dict: reconstruct from to_json output - hat = Hat.from_json(hat_data) - return SmartOrder(data["item"], hat) - - def __eq__(self, other): - return (isinstance(other, SmartOrder) - and self.item == other.item and self.hat == other.hat) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - -@pytest.fixture -def strict(monkeypatch): - """Enable strict typing mode for the duration of a test.""" - monkeypatch.setenv(_STRICT_ENV_VAR, "1") - - -@pytest.fixture -def loose(monkeypatch): - """Explicitly disable strict typing mode.""" - monkeypatch.delenv(_STRICT_ENV_VAR, raising=False) +# =========================================================================== +# Primitive round-trips +# =========================================================================== +import pytest -# =================================================================== -# 1. 
PRIMITIVES (str, int, float, bool, None, list, dict) -# =================================================================== @pytest.mark.parametrize("value", [ None, @@ -177,480 +110,112 @@ def loose(monkeypatch): {"a": 1, "b": [1, 2]}, {"nested": {"deep": {"value": 7}}}, ]) -class TestPrimitiveRoundTrips: - """Primitives must round-trip identically in all three paths.""" - - def test_no_expected_type(self, value): - assert df_loads(df_dumps(value)) == value - - def test_loose_with_matching_type(self, value, loose, caplog): - # Use the actual type of the value as expected_type - et = type(value) if value is not None else type(None) - with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): - result = df_loads(df_dumps(value), expected_type=et) - assert result == value - - def test_strict_with_matching_type(self, value, strict): - et = type(value) if value is not None else type(None) - result = df_loads(df_dumps(value), expected_type=et) - assert result == value - - -# =================================================================== -# 2. 
SIMPLE CUSTOM OBJECTS (dict-returning to_json) -# =================================================================== - -class TestSimpleObject: - - def test_no_expected_type(self): - obj = PlainPerson("andy", 99) - decoded = df_loads(df_dumps(obj)) - assert decoded == obj - - def test_loose_matching_type(self, loose): - obj = PlainPerson("andy", 99) - decoded = df_loads(df_dumps(obj), expected_type=PlainPerson) - assert decoded == obj - - def test_strict_matching_type(self, strict): - obj = PlainPerson("andy", 99) - decoded = df_loads(df_dumps(obj), expected_type=PlainPerson) - assert decoded == obj - - def test_loose_mismatched_type_warns(self, loose, caplog): - encoded = df_dumps(PlainPerson("a", 1)) - with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): - decoded = df_loads(encoded, expected_type=ScalarPerson) - # Loose mode: legacy decoder uses the payload's class - assert isinstance(decoded, PlainPerson) - assert any("payload declares class" in r.message for r in caplog.records) - - def test_strict_mismatched_type_raises(self, strict): - encoded = df_dumps(PlainPerson("a", 1)) - with pytest.raises(TypeError, match="payload declares class"): - df_loads(encoded, expected_type=ScalarPerson) - - -# =================================================================== -# 3. SCALAR-RETURNING to_json -# =================================================================== - -class TestScalarToJson: - - def test_no_expected_type(self): - obj = ScalarPerson("andy") - decoded = df_loads(df_dumps(obj)) - assert decoded == obj - - def test_loose_matching_type(self, loose): - obj = ScalarPerson("andy") - decoded = df_loads(df_dumps(obj), expected_type=ScalarPerson) - assert decoded == obj - - def test_strict_matching_type(self, strict): - obj = ScalarPerson("andy") - decoded = df_loads(df_dumps(obj), expected_type=ScalarPerson) - assert decoded == obj - - -# =================================================================== -# 4. 
DICT WITH OBJECT PROPERTIES e.g. {"person": PlainPerson(...)} -# =================================================================== - -class TestDictWithObjectProperty: - """A plain dict containing a custom object as a value.""" - - def _make_payload(self): - return {"person": PlainPerson("a", 1), "count": 7} - - def test_no_expected_type(self): - """Loose path: object_hook reconstructs nested objects.""" - decoded = df_loads(df_dumps(self._make_payload())) - assert decoded["count"] == 7 - assert isinstance(decoded["person"], PlainPerson) - assert decoded["person"].name == "a" - - def test_loose_expected_dict(self, loose, caplog): - """Loose path + expected_type=dict: works, inner objects reconstructed.""" - with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): - decoded = df_loads(df_dumps(self._make_payload()), expected_type=dict) - assert isinstance(decoded["person"], PlainPerson) - # No warning — top-level is a dict matching expected_type - assert not any("not compatible" in r.message for r in caplog.records) - - def test_strict_encode_fails_for_nested_custom_objects(self, strict): - """Strict mode: a plain dict containing a custom object cannot be - encoded — json.dumps runs without default= so Hat raises TypeError.""" - with pytest.raises(TypeError): - df_dumps(self._make_payload()) - - -# =================================================================== -# 5. 
NESTED OBJECTS — "naive" from_json (expects pre-constructed) -# =================================================================== - -class TestNaiveNestedObject: - """NaiveOrder.from_json expects Hat to already be a Hat instance.""" - - def _make(self): - return NaiveOrder("widget", Hat("red")) - - def test_no_expected_type(self): - """Legacy path: object_hook fires bottom-up, Hat constructed first.""" - decoded = df_loads(df_dumps(self._make())) - assert isinstance(decoded, NaiveOrder) - assert isinstance(decoded.hat, Hat) - assert decoded.hat.color == "red" - - def test_loose_matching_type(self, loose): - """Loose + expected_type: legacy path still fires, nested works.""" - decoded = df_loads(df_dumps(self._make()), expected_type=NaiveOrder) - assert decoded == self._make() - - def test_strict_encode_fails_for_naive_to_json(self, strict): - """Strict mode: NaiveOrder.to_json returns a Hat instance, which - is not natively JSON-serializable. df_dumps should fail at encode.""" - with pytest.raises(TypeError): - df_dumps(self._make()) - - -# =================================================================== -# 6. 
NESTED OBJECTS — "smart" from_json (handles raw dicts) -# =================================================================== - -class TestSmartNestedObject: - """SmartOrder.from_json manually calls Hat.from_json when needed.""" - - def _make(self): - return SmartOrder("gadget", Hat("blue")) - - def test_no_expected_type(self): - decoded = df_loads(df_dumps(self._make())) - assert isinstance(decoded, SmartOrder) - assert decoded.hat == Hat("blue") - - def test_loose_matching_type(self, loose): - decoded = df_loads(df_dumps(self._make()), expected_type=SmartOrder) - assert decoded == self._make() - - def test_strict_matching_type(self, strict): - """Strict mode works: SmartOrder.from_json handles the raw dict.""" - decoded = df_loads(df_dumps(self._make()), expected_type=SmartOrder) - assert decoded == self._make() - assert isinstance(decoded.hat, Hat) - assert decoded.hat.color == "blue" - - -# =================================================================== -# 7. LIST OF OBJECTS -# =================================================================== - -class TestListOfObjects: - - def _make(self): - return [PlainPerson("a", 1), PlainPerson("b", 2)] - - def test_no_expected_type(self): - decoded = df_loads(df_dumps(self._make())) - assert len(decoded) == 2 - assert all(isinstance(p, PlainPerson) for p in decoded) - - def test_loose_expected_list(self, loose): - decoded = df_loads(df_dumps(self._make()), expected_type=list) - assert len(decoded) == 2 - assert all(isinstance(p, PlainPerson) for p in decoded) - - def test_strict_encode_fails_for_nested_custom_objects(self, strict): - """Strict mode: a list of custom objects cannot be encoded — the - list itself doesn't have to_json, and json.dumps runs without - default= so PlainPerson raises TypeError.""" - with pytest.raises(TypeError): - df_dumps(self._make()) - - -# =================================================================== -# 8. 
PRIMITIVE TYPE MISMATCHES -# =================================================================== - -class TestPrimitiveTypeMismatch: - - def test_loose_warns(self, loose, caplog): - with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): - result = df_loads(df_dumps("hello"), expected_type=int) - assert result == "hello" - assert any("not compatible" in r.message for r in caplog.records) - - def test_strict_raises(self, strict): - with pytest.raises(TypeError, match="not compatible with expected type"): - df_loads(df_dumps("hello"), expected_type=int) - - def test_loose_str_expected_dict_warns(self, loose, caplog): - with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): - result = df_loads(df_dumps("hello"), expected_type=dict) - assert result == "hello" - assert any("not compatible" in r.message for r in caplog.records) - - def test_strict_str_expected_dict_raises(self, strict): - with pytest.raises(TypeError): - df_loads(df_dumps("hello"), expected_type=dict) - - -# =================================================================== -# 9. typing CONSTRUCTS (List[int], Optional[str], etc.) 
-# =================================================================== - -class TestTypingConstructs: - """Generic type hints can't be validated with isinstance — we pass - through without error in both modes.""" - - def test_loose_list_of_int(self, loose): - from typing import List - decoded = df_loads(df_dumps([1, 2, 3]), expected_type=List[int]) - assert decoded == [1, 2, 3] +def test_primitive_round_trip(value): + assert df_loads(df_dumps(value)) == value - def test_strict_list_of_int(self, strict): - from typing import List - decoded = df_loads(df_dumps([1, 2, 3]), expected_type=List[int]) - assert decoded == [1, 2, 3] - def test_loose_optional_str(self, loose): - from typing import Optional - decoded = df_loads(df_dumps("hi"), expected_type=Optional[str]) - assert decoded == "hi" - - -# =================================================================== -# 10. STRICT MODE ENV VAR VALUES -# =================================================================== - -class TestStrictModeEnvVar: - - @pytest.mark.parametrize("val", ["1", "true", "yes", "TRUE", "Yes", " 1 "]) - def test_truthy_values_enable_strict(self, monkeypatch, val): - monkeypatch.setenv(_STRICT_ENV_VAR, val) - with pytest.raises(TypeError): - df_loads(df_dumps("hello"), expected_type=int) +# =========================================================================== +# Custom object round-trips (legacy object_hook reconstruction) +# =========================================================================== - @pytest.mark.parametrize("val", ["0", "false", "no", "", "nope"]) - def test_non_truthy_values_stay_loose(self, monkeypatch, val, caplog): - monkeypatch.setenv(_STRICT_ENV_VAR, val) - with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): - result = df_loads(df_dumps("hello"), expected_type=int) - assert result == "hello" - - def test_unset_is_loose(self, monkeypatch): - monkeypatch.delenv(_STRICT_ENV_VAR, raising=False) - result = df_loads(df_dumps("hello"), 
expected_type=int) - assert result == "hello" - - -# =================================================================== -# 10b. STRICT MODE WITHOUT expected_type -# =================================================================== - -class TestStrictNoExpectedType: - """In strict mode, df_loads without expected_type must never call import_module.""" - - def test_primitive_returns_raw(self, strict): - assert df_loads(df_dumps(42)) == 42 - - def test_string_returns_raw(self, strict): - assert df_loads(df_dumps("hello")) == "hello" - - def test_none_returns_raw(self, strict): - assert df_loads(df_dumps(None)) is None - - def test_plain_dict_returns_raw(self, strict): - d = {"key": "value", "n": 1} - assert df_loads(df_dumps(d)) == d - - def test_plain_list_returns_raw(self, strict): - lst = [1, "two", None] - assert df_loads(df_dumps(lst)) == lst +def test_simple_object_round_trip(): + obj = PlainPerson("andy", 99) + assert df_loads(df_dumps(obj)) == obj - def test_custom_object_raises(self, strict): - s = df_dumps(PlainPerson("alice", 30)) - with pytest.raises(TypeError, match="strict mode requires expected_type"): - df_loads(s) - def test_custom_object_error_includes_class(self, strict): - s = df_dumps(PlainPerson("alice", 30)) - with pytest.raises(TypeError, match="PlainPerson"): - df_loads(s) +def test_nested_object_round_trip(): + obj = NestedOrder("widget", Hat("red")) + decoded = df_loads(df_dumps(obj)) + assert decoded == obj + assert isinstance(decoded.hat, Hat) - def test_loose_mode_custom_object_still_works(self, loose): - """Without strict, the legacy path runs even without expected_type.""" - p = PlainPerson("bob", 25) - result = df_loads(df_dumps(p)) - assert isinstance(result, PlainPerson) - assert result.name == "bob" +def test_dict_with_object_property_round_trip(): + payload = {"person": PlainPerson("a", 1), "count": 7} + decoded = df_loads(df_dumps(payload)) + assert decoded["count"] == 7 + assert isinstance(decoded["person"], PlainPerson) + 
assert decoded["person"].name == "a" -# =================================================================== -# 11. WIRE FORMAT VERIFICATION -# =================================================================== -class TestWireFormat: +def test_list_of_objects_round_trip(): + payload = [PlainPerson("a", 1), PlainPerson("b", 2)] + decoded = df_loads(df_dumps(payload)) + assert len(decoded) == 2 + assert all(isinstance(p, PlainPerson) for p in decoded) - def test_df_dumps_matches_legacy_json_dumps(self): - from azure.functions._durable_functions import _serialize_custom_object - value = {"key": "value", "list": [1, 2, 3]} - assert df_dumps(value) == json.dumps(value, default=_serialize_custom_object) - - def test_custom_object_produces_legacy_keys(self): - raw = json.loads(df_dumps(PlainPerson("andy", 99))) - assert raw == { - "__class__": "PlainPerson", - "__module__": __name__, - "__data__": {"name": "andy", "age": 99}, - } - - def test_scalar_to_json_produces_legacy_keys(self): - raw = json.loads(df_dumps(ScalarPerson("andy"))) - assert raw == { - "__class__": "ScalarPerson", - "__module__": __name__, - "__data__": "andy", - } - - def test_nested_object_produces_plain_json_data(self): - """SmartOrder.to_json serializes Hat explicitly, so __data__ - contains plain JSON — no nested legacy envelope.""" - raw = json.loads(df_dumps(SmartOrder("gadget", Hat("blue")))) - assert raw["__class__"] == "SmartOrder" - assert raw["__data__"] == {"item": "gadget", "hat": {"color": "blue"}} - - -# =================================================================== -# 12. 
_get_serialize_default -# =================================================================== - -class TestGetSerializeDefault: - - def test_returns_callable(self): - cb = _get_serialize_default() - assert callable(cb) - def test_produces_legacy_dict(self): - cb = _get_serialize_default() - result = cb(PlainPerson("a", 1)) - assert result == { - "__class__": "PlainPerson", - "__module__": __name__, - "__data__": {"name": "a", "age": 1}, - } +def test_expected_type_is_accepted(): + """expected_type is part of the call signature; happy-path decoding + still reconstructs the object regardless of which impl is active.""" + obj = PlainPerson("andy", 99) + decoded = df_loads(df_dumps(obj), expected_type=PlainPerson) + assert decoded == obj - def test_strict_returns_none(self, strict): - cb = _get_serialize_default() - assert cb is None +# =========================================================================== +# Wire format verification +# =========================================================================== -# =================================================================== -# 13. ENCODE ERRORS -# =================================================================== +def test_primitive_wire_format_is_plain_json(): + assert df_dumps({"a": 1, "b": [1, 2]}) == json.dumps({"a": 1, "b": [1, 2]}) -class TestEncodeErrors: - def test_class_without_to_json(self): - class NoProtocol: - pass - with pytest.raises(TypeError): - df_dumps(NoProtocol()) - - def test_set(self): - with pytest.raises(TypeError): - df_dumps({1, 2, 3}) - - def test_bytes(self): - with pytest.raises(TypeError): - df_dumps(b"hello") - - -# =================================================================== -# 13b. 
STRICT-MODE ENCODE -# =================================================================== - -class TestStrictEncode: - """In strict mode, df_dumps rejects non-serializable nested values.""" - - def test_primitive(self, strict): - assert df_dumps(42) == "42" - - def test_string(self, strict): - assert df_dumps("hello") == '"hello"' - - def test_plain_dict(self, strict): - assert json.loads(df_dumps({"a": 1})) == {"a": 1} - - def test_custom_object_top_level_ok(self, strict): - """Top-level custom object is wrapped in envelope.""" - raw = json.loads(df_dumps(PlainPerson("andy", 99))) - assert raw["__class__"] == "PlainPerson" - assert raw["__data__"] == {"name": "andy", "age": 99} - - def test_strict_smart_order_data_is_plain_json(self, strict): - """SmartOrder.to_json returns plain JSON, so encoding succeeds - and __data__ contains no nested envelopes.""" - raw = json.loads(df_dumps(SmartOrder("gadget", Hat("blue")))) - assert raw["__class__"] == "SmartOrder" - assert raw["__data__"] == {"item": "gadget", "hat": {"color": "blue"}} - - def test_strict_naive_order_fails(self, strict): - """NaiveOrder.to_json returns a Hat instance — not serializable.""" - with pytest.raises(TypeError): - df_dumps(NaiveOrder("widget", Hat("red"))) - - def test_strict_dict_with_custom_value_fails(self, strict): - """Plain dict containing a custom object — not serializable.""" - with pytest.raises(TypeError): - df_dumps({"person": PlainPerson("a", 1)}) +def test_custom_object_wire_format_uses_legacy_envelope(): + raw = json.loads(df_dumps(PlainPerson("andy", 99))) + assert raw["__class__"] == "PlainPerson" + assert raw["__module__"] == PlainPerson.__module__ + assert raw["__data__"] == {"name": "andy", "age": 99} - def test_strict_list_with_custom_value_fails(self, strict): - """List containing custom objects — not serializable.""" - with pytest.raises(TypeError): - df_dumps([PlainPerson("a", 1)]) - def test_loose_dict_with_custom_value_ok(self, loose): - """In loose mode, nested custom 
objects are still auto-wrapped.""" - raw = json.loads(df_dumps({"person": PlainPerson("a", 1)})) - assert raw["person"]["__class__"] == "PlainPerson" +# =========================================================================== +# _get_serialize_default +# =========================================================================== +def test_get_serialize_default_is_usable_with_json_dumps(): + default = _get_serialize_default() + encoded = json.dumps(PlainPerson("andy", 99), default=default) + raw = json.loads(encoded) + assert raw["__class__"] == "PlainPerson" -# =================================================================== -# 14. EDGE CASES -# =================================================================== -class TestEdgeCases: +# =========================================================================== +# Shim wiring +# =========================================================================== - def test_bool_does_not_become_int(self): - """bool is a subclass of int — verify it stays bool.""" - out = df_loads(df_dumps(True)) - assert out is True - assert isinstance(out, bool) +def test_shim_prefers_sdk_serializers_when_available(): + """If the installed SDK exposes df_dumps/df_loads, the shim must + re-export the SDK objects rather than a local fallback.""" + if hasattr(_sdk, "df_dumps"): + assert df_dumps is _sdk.df_dumps + assert df_loads is _sdk.df_loads + else: + # Fallback path: local functions defined in the shim module. 
+ assert df_dumps.__module__.endswith("df_serialization") + assert df_loads.__module__.endswith("df_serialization") - def test_none_with_expected_type_nonetype(self, loose): - assert df_loads(df_dumps(None), expected_type=type(None)) is None - def test_none_with_expected_type_nonetype_strict(self, strict): - assert df_loads(df_dumps(None), expected_type=type(None)) is None +def test_fallback_path_warns_at_import(): + """When the SDK lacks df_dumps/df_loads, importing the shim must emit a + single UserWarning prompting an upgrade. When the SDK provides them, no + such warning is emitted.""" + import importlib + import warnings - def test_empty_dict_expected_dict(self, loose): - assert df_loads(df_dumps({}), expected_type=dict) == {} + from azure.durable_functions.models.utils import df_serialization - def test_empty_list_expected_list(self, strict): - assert df_loads(df_dumps([]), expected_type=list) == [] + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + importlib.reload(df_serialization) - def test_tuple_becomes_list(self): - """Tuples serialize as JSON arrays — come back as lists.""" - assert df_loads(df_dumps((1, 2, 3))) == [1, 2, 3] + upgrade_warnings = [ + w for w in caught + if issubclass(w.category, UserWarning) + and "df_dumps" in str(w.message) + ] - def test_int_dict_keys_become_strings(self): - decoded = df_loads(df_dumps({1: "one", 2: "two"})) - assert decoded == {"1": "one", "2": "two"} + if hasattr(_sdk, "df_dumps"): + assert upgrade_warnings == [] + else: + assert len(upgrade_warnings) == 1 - def test_no_expected_type_no_warning(self, caplog): - """When expected_type is None, no warnings should fire.""" - with caplog.at_level(logging.WARNING, logger=df_serialization.__name__): - df_loads(df_dumps(PlainPerson("a", 1))) - assert not any("not compatible" in r.message for r in caplog.records) - assert not any("payload declares" in r.message for r in caplog.records) From 5372e4d0f802499eebe7e4ec49b0614cefdaf00b Mon 
Sep 17 00:00:00 2001 From: Andy Staples Date: Thu, 4 Jun 2026 10:53:13 -0600 Subject: [PATCH 3/4] Skip lint on 3.13 --- .github/workflows/validate.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/validate.yml b/.github/workflows/validate.yml index 9878cca..a6d50db 100644 --- a/.github/workflows/validate.yml +++ b/.github/workflows/validate.yml @@ -45,6 +45,12 @@ jobs: if: matrix.functions-sdk != '' run: pip install "${{ matrix.functions-sdk }}" - name: Run Linter + # Lint only on the canonical Python version. On Python 3.12+, PEP 701 + # changed f-string tokenization so pycodestyle inspects tokens inside + # f-strings, producing false positives (e.g. the ':' in 'http://' or + # the indentation of multi-line f-string concatenations). Linting is + # environment-agnostic, so running it once on 3.9 is sufficient. + if: matrix.python-version == '3.9' run: | cd azure flake8 . --count --show-source --statistics From 9e3e4a8c9f2ea8466492903660fd1c6797b4b3e7 Mon Sep 17 00:00:00 2001 From: Andy Staples Date: Thu, 4 Jun 2026 12:04:34 -0600 Subject: [PATCH 4/4] PR Feedback --- .../models/DurableEntityContext.py | 19 +++++--- .../models/OrchestratorState.py | 5 +- .../models/utils/type_discovery.py | 33 +++++++++++-- eng/templates/build.yml | 2 +- tests/orchestrator/test_entity.py | 46 +++++++++++++++++++ tests/orchestrator/test_external_event.py | 2 +- 6 files changed, 91 insertions(+), 16 deletions(-) diff --git a/azure/durable_functions/models/DurableEntityContext.py b/azure/durable_functions/models/DurableEntityContext.py index 43c8f32..4f1c4e0 100644 --- a/azure/durable_functions/models/DurableEntityContext.py +++ b/azure/durable_functions/models/DurableEntityContext.py @@ -108,13 +108,9 @@ def from_json(cls, json_str: str) -> Tuple['DurableEntityContext', List[Dict[str json_dict["key"] = json_dict["self"]["key"] json_dict.pop("self") + # Keep the raw serialized state (a JSON string) so get_state() can + # deserialize lazily with an 
expected_type supplied by the user. serialized_state = json_dict["state"] - if serialized_state is not None: - # Keep the raw serialized form so get_state() can deserialize - # lazily with an expected_type supplied by the user. - json_dict["state"] = serialized_state - else: - json_dict["state"] = None batch = json_dict.pop("batch") ctx = cls(**json_dict) @@ -134,6 +130,10 @@ def set_state(self, state: Any) -> None: # should only serialize the state at the end of the batch self._state = state + # The new state is a live Python value, not the raw JSON string + # loaded from the payload. Clear the raw flag so a subsequent + # get_state() in the same batch does not try to re-decode it. + self._state_is_raw = False def get_state(self, initializer: Optional[Callable[[], Any]] = None, expected_type: Optional[type] = None) -> Any: @@ -145,7 +145,11 @@ def get_state(self, initializer: Optional[Callable[[], Any]] = None, A 0-argument function to provide an initial state. Defaults to None. expected_type: Optional[type] The type to decode the state as. When set, the codec uses - this type directly without consulting ``sys.modules``. + this type directly without consulting ``sys.modules``. Note that + the persisted state is decoded lazily on the **first** get_state + call within a batch; an ``expected_type`` supplied on a later + call (after the state has already been decoded or replaced via + set_state) has no effect. 
Returns ------- @@ -199,6 +203,7 @@ def destruct_on_exit(self) -> None: """Delete this entity after the operation completes.""" self._exists = False self._state = None + self._state_is_raw = False def from_json_util(json_str: str, expected_type: Optional[type] = None) -> Any: diff --git a/azure/durable_functions/models/OrchestratorState.py b/azure/durable_functions/models/OrchestratorState.py index 36fa2b2..f32dc63 100644 --- a/azure/durable_functions/models/OrchestratorState.py +++ b/azure/durable_functions/models/OrchestratorState.py @@ -1,10 +1,9 @@ -import json from typing import List, Any, Dict, Optional from azure.durable_functions.models.ReplaySchema import ReplaySchema from .utils.json_utils import add_attrib -from .utils.df_serialization import _get_serialize_default +from .utils.df_serialization import df_dumps from azure.durable_functions.models.actions.Action import Action @@ -114,4 +113,4 @@ def to_json_string(self) -> str: The instance of the object in json string format """ json_dict = self.to_json() - return json.dumps(json_dict, default=_get_serialize_default()) + return df_dumps(json_dict) diff --git a/azure/durable_functions/models/utils/type_discovery.py b/azure/durable_functions/models/utils/type_discovery.py index 64da16c..6f80402 100644 --- a/azure/durable_functions/models/utils/type_discovery.py +++ b/azure/durable_functions/models/utils/type_discovery.py @@ -12,8 +12,10 @@ from __future__ import annotations +import functools import inspect import logging +import typing from typing import Any, Callable, Optional logger = logging.getLogger(__name__) @@ -33,12 +35,35 @@ def _unwrap_function_builder(name_or_callable: Any) -> Optional[Callable]: return None +@functools.lru_cache(maxsize=None) def _return_annotation(fn: Callable) -> Optional[type]: + """Resolve *fn*'s return annotation to a concrete ``type``, or ``None``. 
+ + ``typing.get_type_hints`` is tried first so that string annotations + (``from __future__ import annotations`` / PEP 563) are resolved to the + real object. Results are memoized per function because this runs on + every ``call_activity`` / ``call_sub_orchestrator`` (including replay). + + Limitation: generic aliases such as ``list[Order]`` or + ``Optional[Order]`` are not concrete ``type`` objects, so they resolve + to ``None`` and the caller falls back to module-only resolution. + """ + ann: Any = inspect.Signature.empty try: - sig = inspect.signature(fn) - except (TypeError, ValueError): - return None - ann = sig.return_annotation + hints = typing.get_type_hints(fn) + except Exception: + hints = None + if hints is not None and "return" in hints: + ann = hints["return"] + else: + # get_type_hints couldn't resolve (e.g. forward ref it can't see); + # fall back to the raw signature annotation. + try: + sig = inspect.signature(fn) + except (TypeError, ValueError): + return None + ann = sig.return_annotation + if ann is inspect.Signature.empty: return None return ann if isinstance(ann, type) else None diff --git a/eng/templates/build.yml b/eng/templates/build.yml index fe2064e..24205b9 100644 --- a/eng/templates/build.yml +++ b/eng/templates/build.yml @@ -64,4 +64,4 @@ jobs: - script: | pip install pytest pytest-azurepipelines pytest --ignore=samples-v2 - displayName: 'pytest' \ No newline at end of file + displayName: 'pytest' diff --git a/tests/orchestrator/test_entity.py b/tests/orchestrator/test_entity.py index ceed420..27d3bd9 100644 --- a/tests/orchestrator/test_entity.py +++ b/tests/orchestrator/test_entity.py @@ -94,6 +94,18 @@ def counter_entity_function_raises_exception(context): def counter_entity_function_raises_exception_with_pystein(context): raise Exception("boom!") +def set_then_get_entity(context): + """Entity that sets state (without first reading it) in one operation and + reads it in a later operation. 
Used to exercise set-then-get across a + batch when the entity already has persisted state. + """ + operation = context.operation_name + if operation == "set": + context.set_state(10) + context.set_result("set") + elif operation == "get": + context.set_result(context.get_state(lambda: 0)) + def test_entity_raises_exception(): # Create input batch batch = [] @@ -163,6 +175,40 @@ def test_entity_signal_then_call(): #assert_valid_schema(result) assert_entity_state_equals(expected, result) +def test_entity_set_then_get_with_preexisting_raw_state(): + """Regression test: an entity that already has persisted state must be + able to set_state in one operation and get_state in a later operation + within the same batch. + + ``from_json`` keeps the persisted state in its raw (undecoded) form and + marks it as raw so the first ``get_state`` can decode it lazily with a + user-supplied ``expected_type``. ``set_state`` replaces that raw value + with a live Python value, so it must clear the raw flag -- otherwise a + later ``get_state`` would try to re-decode an already-live value and the + operation would fail. + """ + # Pre-existing persisted state (single-encoded JSON string) is what makes + # from_json mark the loaded state as raw. + batch = [] + add_to_batch(batch, name="set") + add_to_batch(batch, name="get") + context_builder = EntityContextBuilder(batch=batch, state=json.dumps(5)) + + # Run the entity, get observed result + result = get_entity_state_result( + context_builder, + set_then_get_entity, + ) + + # Both operations should succeed; the "get" must observe the value set by + # the earlier "set" (10), not crash trying to re-decode it. 
+ expected_state = entity_base_expected_state() + apply_operation(expected_state, result="set", state=10) + apply_operation(expected_state, result=10, state=10) + expected = expected_state.to_json() + + assert_entity_state_equals(expected, result) + def test_entity_signal_then_call_with_pystein(): """Tests that a simple counter entity outputs the correct value after a sequence of operations. Mostly just a sanity check. diff --git a/tests/orchestrator/test_external_event.py b/tests/orchestrator/test_external_event.py index 86df610..4e92c46 100644 --- a/tests/orchestrator/test_external_event.py +++ b/tests/orchestrator/test_external_event.py @@ -85,4 +85,4 @@ def test_external_event_with_expected_type(): context_builder, generator_function_with_expected_type) assert result["isDone"] is True - assert result["output"] == "hello" \ No newline at end of file + assert result["output"] == "hello"