From 74bcdec16143baf19c25b3901c8a127cf8ab2673 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Wed, 13 May 2026 13:16:59 +0300 Subject: [PATCH] feat: add lifecycle hooks --- nest/common/__init__.py | 7 + nest/common/interfaces.py | 28 +++ nest/core/pynest_application.py | 68 +++++- nest/core/pynest_container.py | 175 +++++++++++++- nest/core/pynest_factory.py | 25 ++ .../test_common/test_lifecycle_interfaces.py | 33 +++ tests/test_core/test_lifecycle_hooks.py | 224 ++++++++++++++++++ 7 files changed, 551 insertions(+), 9 deletions(-) create mode 100644 nest/common/interfaces.py create mode 100644 tests/test_common/test_lifecycle_interfaces.py create mode 100644 tests/test_core/test_lifecycle_hooks.py diff --git a/nest/common/__init__.py b/nest/common/__init__.py index 6deed60..6da2eb3 100644 --- a/nest/common/__init__.py +++ b/nest/common/__init__.py @@ -24,3 +24,10 @@ Res, createParamDecorator, ) +from nest.common.interfaces import ( + BeforeApplicationShutdown, + OnApplicationBootstrap, + OnApplicationShutdown, + OnModuleDestroy, + OnModuleInit, +) diff --git a/nest/common/interfaces.py b/nest/common/interfaces.py new file mode 100644 index 0000000..737f1e7 --- /dev/null +++ b/nest/common/interfaces.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from typing import Any, Optional, Protocol, runtime_checkable + + +@runtime_checkable +class OnModuleInit(Protocol): + def on_module_init(self) -> Any: ... + + +@runtime_checkable +class OnApplicationBootstrap(Protocol): + def on_application_bootstrap(self) -> Any: ... + + +@runtime_checkable +class BeforeApplicationShutdown(Protocol): + def before_application_shutdown(self, signal: Optional[str]) -> Any: ... + + +@runtime_checkable +class OnModuleDestroy(Protocol): + def on_module_destroy(self) -> Any: ... + + +@runtime_checkable +class OnApplicationShutdown(Protocol): + def on_application_shutdown(self, signal: Optional[str]) -> Any: ... diff --git a/nest/core/pynest_application.py b/nest/core/pynest_application.py index 707ec88..45f2295 100644 --- a/nest/core/pynest_application.py +++ b/nest/core/pynest_application.py @@ -1,7 +1,10 @@ from __future__ import annotations +import asyncio import inspect -from typing import Any +import signal as signal_module +from contextlib import asynccontextmanager +from typing import Any, Iterable, Optional from fastapi import FastAPI, Request from fastapi.responses import JSONResponse @@ -18,6 +21,9 @@ class PyNestApp: def __init__(self, container: PyNestContainer, http_server: FastAPI) -> None: self.container = container self.http_server = http_server + self._closed = False + self._closing = False + self._install_lifespan_shutdown() routes_resolver = RoutesResolver(self.container, self.http_server) routes_resolver.register_routes() @@ -33,6 +39,31 @@ def use(self, middleware: type, **options: Any) -> "PyNestApp": self.http_server.add_middleware(middleware, **options) return self + def enable_shutdown_hooks( + self, signals: Optional[Iterable[signal_module.Signals]] = None + ) -> "PyNestApp": + """Register process signal handlers that trigger graceful shutdown.""" + shutdown_signals = tuple( + signals or (signal_module.SIGTERM, signal_module.SIGINT) + ) + for shutdown_signal in shutdown_signals: + signal_module.signal( + shutdown_signal, self._make_signal_handler(shutdown_signal) + ) + return self + + async def close(self, signal: Optional[str] = None) -> None: + """Run graceful application shutdown lifecycle hooks once.""" + if self._closed or self._closing: + return + + self._closing = True + try: + await self.container.shutdown_lifecycle(signal) + self._closed = True + finally: + self._closing = False + def use_global_filters(self, *filters) -> "PyNestApp": """Register one or more exception filters that apply to every route. @@ -73,3 +104,38 @@ async def handler(request: Request, exc: Exception): return result self.http_server.add_exception_handler(exc_type, handler) + + def _make_signal_handler(self, shutdown_signal: signal_module.Signals): + def handler(signum, frame): + self._close_from_signal(self._signal_name(signum or shutdown_signal)) + + return handler + + def _close_from_signal(self, signal_name: str) -> None: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + asyncio.run(self.close(signal_name)) + return + + loop.create_task(self.close(signal_name)) + + @staticmethod + def _signal_name(signum) -> str: + try: + return signal_module.Signals(signum).name + except ValueError: + return str(signum) + + def _install_lifespan_shutdown(self) -> None: + original_lifespan_context = self.http_server.router.lifespan_context + + @asynccontextmanager + async def lifespan_context(app: FastAPI): + async with original_lifespan_context(app) as state: + try: + yield state + finally: + await self.close() + + self.http_server.router.lifespan_context = lifespan_context diff --git a/nest/core/pynest_container.py b/nest/core/pynest_container.py index a7d09dc..120dc2c 100644 --- a/nest/core/pynest_container.py +++ b/nest/core/pynest_container.py @@ -5,12 +5,27 @@ from typing import Any, Dict, List, Optional, Type, Union from nest.common.exceptions import CircularDependencyException +from nest.common.interfaces import ( + BeforeApplicationShutdown, + OnApplicationBootstrap, + OnApplicationShutdown, + OnModuleDestroy, + OnModuleInit, +) from nest.common.module import CompiledModule, ModuleCompiler, ModuleTokenFactory from nest.common.provider import InjectionToken, ProviderDescriptor from nest.core.dependency_graph import DependencyGraph from nest.core.encapsulation import validate_module_encapsulation from nest.core.injector_module import build_injector, _to_key +_LIFECYCLE_METHOD_NAMES = ( + "on_module_init", + "on_application_bootstrap", + "before_application_shutdown", + "on_module_destroy", + "on_application_shutdown", +) + class ModuleRef: """Internal container representation of a registered module.""" @@ -37,6 +52,9 @@ def __init__(self) -> None: self._modules: Dict[str, ModuleRef] = {} self._all_descriptors: List[ProviderDescriptor] = [] self._controller_classes: List[Type] = [] + self._module_instances: Dict[str, Any] = {} + self._lifecycle_initialized = False + self._lifecycle_shutdown = False self._module_token_factory = ModuleTokenFactory() self._module_compiler = ModuleCompiler(self._module_token_factory) @@ -105,11 +123,73 @@ def clear(self) -> None: self._modules.clear() self._all_descriptors.clear() self._controller_classes.clear() + self._module_instances.clear() + self._lifecycle_initialized = False + self._lifecycle_shutdown = False + + async def initialize_lifecycle(self) -> None: + """Run module init and application bootstrap hooks once.""" + if self._injector is None: + raise RuntimeError( + "Container not built. Call container.build() before lifecycle hooks." + ) + if self._lifecycle_initialized: + return + + for module_ref in self._modules.values(): + await self._call_hooks( + self._get_module_lifecycle_instances(module_ref), + OnModuleInit, + "on_module_init", + ) + + await self._call_hooks( + self._get_all_lifecycle_instances(), + OnApplicationBootstrap, + "on_application_bootstrap", + ) + self._lifecycle_initialized = True + + async def shutdown_lifecycle(self, signal: Optional[str] = None) -> None: + """Run application shutdown hooks once in graceful shutdown order.""" + if self._injector is None: + raise RuntimeError( + "Container not built. Call container.build() before lifecycle hooks." + ) + if self._lifecycle_shutdown: + return + + modules = list(self._modules.values()) + for module_ref in reversed(modules): + await self._call_hooks( + self._get_module_lifecycle_instances(module_ref), + BeforeApplicationShutdown, + "before_application_shutdown", + signal, + ) + + for module_ref in reversed(modules): + await self._call_hooks( + self._get_module_lifecycle_instances(module_ref), + OnModuleDestroy, + "on_module_destroy", + ) + + for module_ref in reversed(modules): + await self._call_hooks( + self._get_module_lifecycle_instances(module_ref), + OnApplicationShutdown, + "on_application_shutdown", + signal, + ) + + self._lifecycle_shutdown = True # ── Internal ─────────────────────────────────────────────────────────────── def _make_controller_descriptors(self) -> List[ProviderDescriptor]: from nest.common.provider import Scope + return [ ProviderDescriptor(provide=cls, use_class=cls, scope=Scope.SINGLETON) for cls in self._controller_classes @@ -117,8 +197,6 @@ def _make_controller_descriptors(self) -> List[ProviderDescriptor]: def _validate_dependency_graph(self) -> None: """Build a DAG from all class providers and raise CircularDependencyException on cycles.""" - import sys - graph = DependencyGraph() # Build a name→class lookup from all registered providers so forward refs can be resolved @@ -162,9 +240,90 @@ def _validate_dependency_graph(self) -> None: cycles = graph.detect_cycles() if cycles: - chain = " → ".join( - getattr(n, "__name__", repr(n)) for n in cycles[0] - ) - raise CircularDependencyException( - f"Circular dependency detected: {chain}" - ) + chain = " → ".join(getattr(n, "__name__", repr(n)) for n in cycles[0]) + raise CircularDependencyException(f"Circular dependency detected: {chain}") + + def _get_all_lifecycle_instances(self) -> List[Any]: + instances: List[Any] = [] + seen: set[int] = set() + for module_ref in self._modules.values(): + for instance in self._get_module_lifecycle_instances(module_ref): + instance_id = id(instance) + if instance_id in seen: + continue + seen.add(instance_id) + instances.append(instance) + return instances + + def _get_module_lifecycle_instances(self, module_ref: ModuleRef) -> List[Any]: + instances: List[Any] = [] + seen: set[int] = set() + + for desc in module_ref.compiled.provider_descriptors: + instance = self.get(desc.provide) + instance_id = id(instance) + if instance_id in seen: + continue + seen.add(instance_id) + instances.append(instance) + + module_instance = self._get_module_instance(module_ref) + if module_instance is not None and id(module_instance) not in seen: + instances.append(module_instance) + + return instances + + def _get_module_instance(self, module_ref: ModuleRef) -> Optional[Any]: + if module_ref.token in self._module_instances: + return self._module_instances[module_ref.token] + + if not any( + callable(getattr(module_ref.metatype, name, None)) + for name in _LIFECYCLE_METHOD_NAMES + ): + return None + + instance = self._instantiate_module(module_ref.metatype) + self._module_instances[module_ref.token] = instance + return instance + + def _instantiate_module(self, module_class: Type) -> Any: + try: + signature = inspect.signature(module_class.__init__) + except (TypeError, ValueError): + return module_class() + + kwargs = {} + for param in list(signature.parameters.values())[1:]: + if param.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + continue + if param.annotation is not inspect.Parameter.empty: + kwargs[param.name] = self.get(param.annotation) + elif param.default is inspect.Parameter.empty: + raise RuntimeError( + f"Cannot instantiate module {module_class.__name__}: " + f"constructor parameter {param.name!r} has no type annotation" + ) + + return module_class(**kwargs) + + async def _call_hooks( + self, instances: List[Any], protocol: Type, method_name: str, *args: Any + ) -> None: + calls = [ + self._call_hook(instance, method_name, *args) + for instance in instances + if isinstance(instance, protocol) + ] + if calls: + import asyncio + + await asyncio.gather(*calls) + + async def _call_hook(self, instance: Any, method_name: str, *args: Any) -> None: + result = getattr(instance, method_name)(*args) + if inspect.isawaitable(result): + await result diff --git a/nest/core/pynest_factory.py b/nest/core/pynest_factory.py index 9058d74..a74b24f 100644 --- a/nest/core/pynest_factory.py +++ b/nest/core/pynest_factory.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio +import threading from abc import ABC, abstractmethod from typing import Type, TypeVar @@ -34,6 +36,7 @@ def create(main_module: Type[ModuleType], **kwargs) -> PyNestApp: container = PyNestContainer() container.add_module(main_module) container.build() + PyNestFactory._run_async(container.initialize_lifecycle()) http_server = FastAPI(**kwargs) return PyNestApp(container, http_server) @@ -41,3 +44,25 @@ def create(main_module: Type[ModuleType], **kwargs) -> PyNestApp: @staticmethod def _create_server(**kwargs) -> FastAPI: return FastAPI(**kwargs) + + @staticmethod + def _run_async(coro): + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + + result = {} + + def runner(): + try: + result["value"] = asyncio.run(coro) + except BaseException as exc: + result["error"] = exc + + thread = threading.Thread(target=runner) + thread.start() + thread.join() + if "error" in result: + raise result["error"] + return result.get("value") diff --git a/tests/test_common/test_lifecycle_interfaces.py b/tests/test_common/test_lifecycle_interfaces.py new file mode 100644 index 0000000..cedecd5 --- /dev/null +++ b/tests/test_common/test_lifecycle_interfaces.py @@ -0,0 +1,33 @@ +from nest.common.interfaces import ( + BeforeApplicationShutdown, + OnApplicationBootstrap, + OnApplicationShutdown, + OnModuleDestroy, + OnModuleInit, +) + + +def test_lifecycle_interfaces_are_runtime_checkable(): + class HookedProvider: + def on_module_init(self): + pass + + def on_application_bootstrap(self): + pass + + def before_application_shutdown(self, signal): + pass + + def on_module_destroy(self): + pass + + def on_application_shutdown(self, signal): + pass + + provider = HookedProvider() + + assert isinstance(provider, OnModuleInit) + assert isinstance(provider, OnApplicationBootstrap) + assert isinstance(provider, BeforeApplicationShutdown) + assert isinstance(provider, OnModuleDestroy) + assert isinstance(provider, OnApplicationShutdown) diff --git a/tests/test_core/test_lifecycle_hooks.py b/tests/test_core/test_lifecycle_hooks.py new file mode 100644 index 0000000..606794d --- /dev/null +++ b/tests/test_core/test_lifecycle_hooks.py @@ -0,0 +1,224 @@ +import asyncio +import signal + +from fastapi.testclient import TestClient + +from nest.common.interfaces import ( + BeforeApplicationShutdown, + OnApplicationBootstrap, + OnApplicationShutdown, + OnModuleDestroy, + OnModuleInit, +) +from nest.core import Injectable, Module, PyNestFactory + + +def test_factory_runs_boot_hooks_in_import_order_before_bootstrap(): + events = [] + + @Injectable + class LeafService(OnModuleInit): + def on_module_init(self): + events.append("leaf-service:module-init") + + @Module(providers=[LeafService]) + class LeafModule: + pass + + @Injectable + class RootService(OnModuleInit, OnApplicationBootstrap): + async def on_module_init(self): + await asyncio.sleep(0) + events.append("root-service:module-init") + + async def on_application_bootstrap(self): + await asyncio.sleep(0) + events.append("root-service:bootstrap") + + @Module(imports=[LeafModule], providers=[RootService]) + class RootModule(OnModuleInit, OnApplicationBootstrap): + def on_module_init(self): + events.append("root-module:module-init") + + def on_application_bootstrap(self): + events.append("root-module:bootstrap") + + PyNestFactory.create(RootModule) + + assert set(events) == { + "leaf-service:module-init", + "root-service:module-init", + "root-module:module-init", + "root-service:bootstrap", + "root-module:bootstrap", + } + assert events.index("leaf-service:module-init") < events.index( + "root-service:module-init" + ) + assert events.index("leaf-service:module-init") < events.index( + "root-module:module-init" + ) + + module_init_indices = [ + index for index, event in enumerate(events) if event.endswith(":module-init") + ] + bootstrap_indices = [ + index for index, event in enumerate(events) if event.endswith(":bootstrap") + ] + assert max(module_init_indices) < min(bootstrap_indices) + + +def test_app_close_runs_shutdown_hooks_in_phase_order_and_reverse_module_order(): + events = [] + + @Injectable + class LeafService( + BeforeApplicationShutdown, OnModuleDestroy, OnApplicationShutdown + ): + async def before_application_shutdown(self, signal): + await asyncio.sleep(0) + events.append(f"leaf-service:before:{signal}") + + def on_module_destroy(self): + events.append("leaf-service:destroy") + + def on_application_shutdown(self, signal): + events.append(f"leaf-service:shutdown:{signal}") + + @Module(providers=[LeafService]) + class LeafModule: + pass + + @Injectable + class RootService( + BeforeApplicationShutdown, OnModuleDestroy, OnApplicationShutdown + ): + def before_application_shutdown(self, signal): + events.append(f"root-service:before:{signal}") + + async def on_module_destroy(self): + await asyncio.sleep(0) + events.append("root-service:destroy") + + def on_application_shutdown(self, signal): + events.append(f"root-service:shutdown:{signal}") + + @Module(imports=[LeafModule], providers=[RootService]) + class RootModule(BeforeApplicationShutdown, OnModuleDestroy, OnApplicationShutdown): + def before_application_shutdown(self, signal): + events.append(f"root-module:before:{signal}") + + def on_module_destroy(self): + events.append("root-module:destroy") + + async def on_application_shutdown(self, signal): + await asyncio.sleep(0) + events.append(f"root-module:shutdown:{signal}") + + app = PyNestFactory.create(RootModule) + + asyncio.run(app.close("SIGINT")) + asyncio.run(app.close("SIGTERM")) + + assert set(events) == { + "root-service:before:SIGINT", + "root-module:before:SIGINT", + "leaf-service:before:SIGINT", + "root-service:destroy", + "root-module:destroy", + "leaf-service:destroy", + "root-service:shutdown:SIGINT", + "root-module:shutdown:SIGINT", + "leaf-service:shutdown:SIGINT", + } + assert all("SIGTERM" not in event for event in events) + + before_indices = [ + index for index, event in enumerate(events) if ":before:" in event + ] + destroy_indices = [ + index for index, event in enumerate(events) if event.endswith(":destroy") + ] + shutdown_indices = [ + index for index, event in enumerate(events) if ":shutdown:" in event + ] + assert max(before_indices) < min(destroy_indices) + assert max(destroy_indices) < min(shutdown_indices) + + for phase in ("before:SIGINT", "destroy", "shutdown:SIGINT"): + leaf_event = f"leaf-service:{phase}" + root_events = [ + f"root-service:{phase}", + f"root-module:{phase}", + ] + assert all( + events.index(root_event) < events.index(leaf_event) + for root_event in root_events + ) + + +def test_fastapi_lifespan_closes_database_style_provider(): + events = [] + + @Injectable + class DatabaseService(OnModuleInit, OnModuleDestroy): + def __init__(self): + self.connected = False + + async def on_module_init(self): + await asyncio.sleep(0) + self.connected = True + events.append("connect") + + async def on_module_destroy(self): + await asyncio.sleep(0) + self.connected = False + events.append("disconnect") + + @Module(providers=[DatabaseService]) + class DatabaseModule: + pass + + app = PyNestFactory.create(DatabaseModule) + database = app.container.get(DatabaseService) + + assert database.connected is True + assert events == ["connect"] + + with TestClient(app.get_server()): + assert database.connected is True + + assert database.connected is False + assert events == ["connect", "disconnect"] + + +def test_enable_shutdown_hooks_wires_sigint_and_sigterm_to_app_close(monkeypatch): + events = [] + + @Injectable + class CleanupService(OnModuleDestroy): + def on_module_destroy(self): + events.append("closed") + + @Module(providers=[CleanupService]) + class CleanupModule: + pass + + registered_handlers = {} + + def fake_signal(signal_number, handler): + registered_handlers[signal_number] = handler + + monkeypatch.setattr(signal, "signal", fake_signal) + + app = PyNestFactory.create(CleanupModule) + result = app.enable_shutdown_hooks() + + assert result is app + assert set(registered_handlers) == {signal.SIGTERM, signal.SIGINT} + + sigterm_handler = registered_handlers[signal.SIGTERM] + sigterm_handler(signal.SIGTERM, None) + sigterm_handler(signal.SIGTERM, None) + + assert events == ["closed"]