From 15b4d44097c670f7be8b982313b6e8bfde205a8f Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Thu, 21 May 2026 09:15:52 +0300 Subject: [PATCH 1/2] feat: add DatabaseModule root registration --- nest/core/database/__init__.py | 19 ++ nest/core/database/database_module.py | 266 ++++++++++++++++ nest/core/decorators/database.py | 6 +- nest/core/injector_module.py | 29 +- nest/core/pynest_container.py | 10 + pyproject.toml | 1 + tests/test_cli/test_orm_templates.py | 60 ++++ .../test_database/test_database_module.py | 301 ++++++++++++++++++ 8 files changed, 686 insertions(+), 6 deletions(-) create mode 100644 nest/core/database/database_module.py create mode 100644 tests/test_cli/test_orm_templates.py create mode 100644 tests/test_core/test_database/test_database_module.py diff --git a/nest/core/database/__init__.py b/nest/core/database/__init__.py index e69de29..c76ee3d 100644 --- a/nest/core/database/__init__.py +++ b/nest/core/database/__init__.py @@ -0,0 +1,19 @@ +from nest.core.database.database_module import ( + DATABASE_ENGINE, + DATABASE_OPTIONS, + DATABASE_SESSION_FACTORY, + DatabaseModule, + DatabaseOptions, + DatabaseService, +) +from nest.core.database.orm_provider import Base + +__all__ = [ + "Base", + "DATABASE_ENGINE", + "DATABASE_OPTIONS", + "DATABASE_SESSION_FACTORY", + "DatabaseModule", + "DatabaseOptions", + "DatabaseService", +] diff --git a/nest/core/database/database_module.py b/nest/core/database/database_module.py new file mode 100644 index 0000000..d1bf1aa --- /dev/null +++ b/nest/core/database/database_module.py @@ -0,0 +1,266 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager, contextmanager +from dataclasses import dataclass, field +from typing import Any, Dict, Generator, Optional, Type + +from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.orm import Session, sessionmaker + +from nest.common.provider import InjectionToken +from nest.core.decorators.module import Module +from nest.core.database.orm_config import AsyncConfigFactory, ConfigFactory +from nest.core.database.orm_provider import Base + +DATABASE_OPTIONS = InjectionToken( + "DATABASE_OPTIONS", "Normalized DatabaseModule.for_root options" +) +DATABASE_ENGINE = InjectionToken("DATABASE_ENGINE", "SQLAlchemy engine") +DATABASE_SESSION_FACTORY = InjectionToken( + "DATABASE_SESSION_FACTORY", "SQLAlchemy session factory" +) + + +@dataclass(frozen=True) +class DatabaseOptions: + driver: str + config_params: Dict[str, Any] + async_mode: bool = False + engine_params: Dict[str, Any] = field(default_factory=dict) + session_params: Dict[str, Any] = field(default_factory=dict) + create_all: bool = False + base: Type[Any] = Base + + +class DatabaseService: + """Lifecycle-aware SQLAlchemy service registered by DatabaseModule.""" + + def __init__( + self, + options: DatabaseOptions, + engine: Any, + session_factory: Any, + ) -> None: + self.options = options + self.engine = engine + self.session_factory = session_factory + self.Base = options.base + + def on_module_init(self): + if not self.options.create_all: + return None + return self.create_all() + + def on_module_destroy(self): + result = self.engine.dispose() + return result + + def create_all(self): + if self.options.async_mode: + return self._create_all_async() + self.Base.metadata.create_all(bind=self.engine) + return None + + async def _create_all_async(self) -> None: + async with self.engine.begin() as conn: + await conn.run_sync(self.Base.metadata.create_all) + + def drop_all(self): + if self.options.async_mode: + return self._drop_all_async() + self.Base.metadata.drop_all(bind=self.engine) + return None + + async def _drop_all_async(self) -> None: + async with self.engine.begin() as conn: + await conn.run_sync(self.Base.metadata.drop_all) + + def session(self): + if self.options.async_mode: + return self._async_session() + return self._sync_session() + + def get_session(self): + return self.session() + + def get_db(self): + if self.options.async_mode: + return self._async_db() + return self._sync_db() + + @contextmanager + def _sync_session(self) -> Generator[Session, None, None]: + db = self.session_factory() + try: + yield db + except Exception: + db.rollback() + raise + finally: + db.close() + + def _sync_db(self) -> Session: + return self.session_factory() + + @asynccontextmanager + async def _async_session(self) -> AsyncSession: + db = self.session_factory() + try: + yield db + except Exception: + await db.rollback() + raise + finally: + await db.close() + + async def _async_db(self): + db = self.session_factory() + try: + yield db + finally: + await db.close() + + +def create_database_engine(options: DatabaseOptions): + config_factory = AsyncConfigFactory if options.async_mode else ConfigFactory + engine_factory = create_async_engine if options.async_mode else create_engine + config_class = config_factory(db_type=options.driver).get_config() + config_url = config_class(**options.config_params).get_engine_url() + return engine_factory(config_url, **options.engine_params) + + +def create_database_session_factory(options: DatabaseOptions, engine: Any): + if options.async_mode: + session_params = {"expire_on_commit": False, "class_": AsyncSession} + session_params.update(options.session_params) + return async_sessionmaker(engine, **session_params) + return sessionmaker(engine, **options.session_params) + + +def create_database_service( + options: DatabaseOptions, + engine: Any, + session_factory: Any, +) -> DatabaseService: + return DatabaseService(options, engine, session_factory) + + +@Module(imports=[], providers=[], exports=[]) +class DatabaseModule: + @classmethod + def for_root( + cls, + driver: str = "postgresql", + *, + database: Optional[str] = None, + db_name: Optional[str] = None, + config_params: Optional[Dict[str, Any]] = None, + host: Optional[str] = None, + user: Optional[str] = None, + password: Optional[str] = None, + port: Optional[int] = None, + async_mode: bool = False, + engine_params: Optional[Dict[str, Any]] = None, + session_params: Optional[Dict[str, Any]] = None, + create_all: bool = False, + base: Type[Any] = Base, + is_global: bool = True, + **extra_config: Any, + ): + normalized_config = _normalize_config_params( + config_params=config_params, + database=database, + db_name=db_name, + host=host, + user=user, + password=password, + port=port, + extra_config=extra_config, + ) + options = DatabaseOptions( + driver=driver, + config_params=normalized_config, + async_mode=async_mode, + engine_params=engine_params or {}, + session_params=session_params or {}, + create_all=create_all, + base=base, + ) + + providers = [ + {"provide": DATABASE_OPTIONS, "useValue": options}, + { + "provide": DATABASE_ENGINE, + "useFactory": create_database_engine, + "inject": [DATABASE_OPTIONS], + }, + { + "provide": DATABASE_SESSION_FACTORY, + "useFactory": create_database_session_factory, + "inject": [DATABASE_OPTIONS, DATABASE_ENGINE], + }, + { + "provide": DatabaseService, + "useFactory": create_database_service, + "inject": [ + DATABASE_OPTIONS, + DATABASE_ENGINE, + DATABASE_SESSION_FACTORY, + ], + }, + ] + + module_name = _configured_module_name(driver=driver, async_mode=async_mode) + configured_module = type(module_name, (cls,), {}) + setattr(configured_module, "__pynest_database_root__", True) + return Module( + imports=[], + providers=providers, + exports=[ + DATABASE_OPTIONS, + DATABASE_ENGINE, + DATABASE_SESSION_FACTORY, + DatabaseService, + ], + is_global=is_global, + )(configured_module) + + +def _normalize_config_params( + *, + config_params: Optional[Dict[str, Any]], + database: Optional[str], + db_name: Optional[str], + host: Optional[str], + user: Optional[str], + password: Optional[str], + port: Optional[int], + extra_config: Dict[str, Any], +) -> Dict[str, Any]: + normalized = dict(config_params or {}) + + database_name = db_name if db_name is not None else database + if database_name is not None and "db_name" not in normalized: + normalized["db_name"] = database_name + + for key, value in { + "host": host, + "user": user, + "password": password, + "port": port, + }.items(): + if value is not None and key not in normalized: + normalized[key] = value + + for key, value in extra_config.items(): + if value is not None and key not in normalized: + normalized[key] = value + + return normalized + + +def _configured_module_name(driver: str, async_mode: bool) -> str: + prefix = "Async" if async_mode else "" + normalized_driver = "".join(part.capitalize() for part in driver.split("_")) + return f"{prefix}{normalized_driver}DatabaseModule" diff --git a/nest/core/decorators/database.py b/nest/core/decorators/database.py index 5f7ea2c..1583dd2 100644 --- a/nest/core/decorators/database.py +++ b/nest/core/decorators/database.py @@ -10,8 +10,8 @@ def db_request_handler(func): """ Decorator that wraps ORM service methods with timing, logging, and HTTP error - conversion. Session lifecycle (open / commit / rollback / close) is the - responsibility of each service method — use config.get_session() there. + conversion. Session lifecycle (open / commit / rollback / close) is the + responsibility of each service method; use DatabaseService.session() there. """ def wrapper(self, *args, **kwargs): @@ -32,7 +32,7 @@ def wrapper(self, *args, **kwargs): def async_db_request_handler(func): """ Async version of db_request_handler. Session lifecycle is the caller's - responsibility (pass session via Depends or use config.get_session()). + responsibility; use DatabaseService.session() in the service method. """ async def wrapper(*args, **kwargs): diff --git a/nest/core/injector_module.py b/nest/core/injector_module.py index bf32b20..34b96c8 100644 --- a/nest/core/injector_module.py +++ b/nest/core/injector_module.py @@ -32,8 +32,7 @@ class PyNestInjectorModule(InjectorModule): def __init__(self, descriptors: List[ProviderDescriptor]) -> None: self._descriptors = [ - d for d in descriptors - if d.use_factory is None and d.use_existing is None + d for d in descriptors if d.use_factory is None and d.use_existing is None ] def configure(self, binder) -> None: @@ -59,8 +58,14 @@ def build_injector(descriptors: List[ProviderDescriptor]) -> Injector: from injector import InstanceProvider injector = Injector([PyNestInjectorModule(descriptors)]) + provider_counts = {} + last_provider_index = {} + for index, desc in enumerate(descriptors): + key = _to_key(desc.provide) + provider_counts[key] = provider_counts.get(key, 0) + 1 + last_provider_index[key] = index - for desc in descriptors: + for index, desc in enumerate(descriptors): key = _to_key(desc.provide) if desc.use_factory is not None: @@ -73,4 +78,22 @@ def build_injector(descriptors: List[ProviderDescriptor]) -> Injector: existing_instance = injector.get(_to_key(desc.use_existing)) injector.binder.bind(key, to=InstanceProvider(existing_instance)) + elif ( + desc.use_value is not None + and provider_counts[key] > 1 + and last_provider_index[key] == index + ): + injector.binder.bind(key, to=InstanceProvider(desc.use_value)) + + elif ( + desc.use_class is not None + and provider_counts[key] > 1 + and last_provider_index[key] == index + ): + injector.binder.bind( + key, + to=desc.use_class, + scope=_injector_scope(desc.scope), + ) + return injector diff --git a/nest/core/pynest_container.py b/nest/core/pynest_container.py index 120dc2c..69a57f8 100644 --- a/nest/core/pynest_container.py +++ b/nest/core/pynest_container.py @@ -57,6 +57,7 @@ def __init__(self) -> None: self._lifecycle_shutdown = False self._module_token_factory = ModuleTokenFactory() self._module_compiler = ModuleCompiler(self._module_token_factory) + self._database_root_registered = False # ── Public API ───────────────────────────────────────────────────────────── @@ -74,6 +75,14 @@ def module_compiler(self): def add_module(self, module_class: Type) -> dict: """Compile and register a module and all its imports recursively.""" + if getattr(module_class, "__pynest_database_root__", False): + if self._database_root_registered: + raise RuntimeError( + "Only one DatabaseModule.for_root() can be registered per " + "application. Named database connections are not supported yet." + ) + self._database_root_registered = True + compiled = self._module_compiler.compile(module_class) token = compiled.token @@ -126,6 +135,7 @@ def clear(self) -> None: self._module_instances.clear() self._lifecycle_initialized = False self._lifecycle_shutdown = False + self._database_root_registered = False async def initialize_lifecycle(self) -> None: """Run module init and application bootstrap hooks once.""" diff --git a/pyproject.toml b/pyproject.toml index 4b8da1f..1035b8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ test = [ "beanie>=1.27.0,<2.0.0", "python-dotenv>=1.0.1,<2.0.0", "aiosqlite>=0.19.0,<1.0.0", + "greenlet>=3.1.1,<4.0.0", "websockets>=13.0,<16.0", ] docs = [ diff --git a/tests/test_cli/test_orm_templates.py b/tests/test_cli/test_orm_templates.py new file mode 100644 index 0000000..f8529ea --- /dev/null +++ b/tests/test_cli/test_orm_templates.py @@ -0,0 +1,60 @@ +from nest.cli.templates.postgres_template import AsyncPostgresqlTemplate +from nest.cli.templates.mysql_template import AsyncMySQLTemplate, MySQLTemplate +from nest.cli.templates.sqlite_template import SQLiteTemplate + + +def test_sync_orm_app_template_uses_database_module_for_root(): + template = SQLiteTemplate("book") + + app_file = template.app_file() + config_file = template.config_file() + service_file = template.service_file() + entity_file = template.entity_file() + + assert "from nest.core.database import DatabaseModule" in app_file + assert "from .config import DATABASE_CONFIG" in app_file + assert "DatabaseModule.for_root(**DATABASE_CONFIG)" in app_file + assert "create_all=True" in config_file + assert "config.create_all" not in app_file + assert "OrmProvider" not in config_file + assert "DATABASE_CONFIG = dict(" in config_file + assert "from nest.core.database import DatabaseService" in service_file + assert "def __init__(self, db: DatabaseService):" in service_file + assert "with self.db.session() as session:" in service_file + assert "from src.config import config" not in service_file + assert "from nest.core.database import Base" in entity_file + + +def test_async_orm_template_uses_injected_database_service(): + template = AsyncPostgresqlTemplate("book") + + app_file = template.app_file() + config_file = template.config_file() + service_file = template.service_file() + controller_file = template.controller_file() + + assert "DatabaseModule.for_root(**DATABASE_CONFIG)" in app_file + assert '"async_mode": True' in config_file + assert '"create_all": True' in config_file + assert "AsyncOrmProvider" not in config_file + assert "from nest.core.database import DatabaseService" in service_file + assert "def __init__(self, db: DatabaseService):" in service_file + assert "async with self.db.session() as session:" in service_file + assert "Depends(config.get_db)" not in controller_file + assert "AsyncSession" not in controller_file + + +def test_orm_template_requirements_include_sqlalchemy_runtime(): + sync_requirements = SQLiteTemplate("book").requirements_file() + async_requirements = AsyncPostgresqlTemplate("book").requirements_file() + + assert "sqlalchemy" in sync_requirements.lower() + assert "sqlalchemy" in async_requirements.lower() + + +def test_mysql_orm_templates_default_missing_port_environment_variables(): + sync_config = MySQLTemplate("book").config_file() + async_config = AsyncMySQLTemplate("book").config_file() + + assert 'os.getenv("MYSQL_PORT", 3306)' in sync_config + assert 'os.getenv("MYSQL_PORT", 3306)' in async_config diff --git a/tests/test_core/test_database/test_database_module.py b/tests/test_core/test_database/test_database_module.py new file mode 100644 index 0000000..5424d01 --- /dev/null +++ b/tests/test_core/test_database/test_database_module.py @@ -0,0 +1,301 @@ +import asyncio + +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import Column, Integer, String, inspect, select +from sqlalchemy.orm import DeclarativeBase + +from nest.core import Controller, Get, Injectable, Module, PyNestFactory +from nest.core.database import ( + DATABASE_ENGINE, + DATABASE_OPTIONS, + DATABASE_SESSION_FACTORY, + DatabaseModule, + DatabaseOptions, + DatabaseService, +) + + +def test_database_module_for_root_registers_core_providers(tmp_path): + class LocalBase(DeclarativeBase): + pass + + database_name = str(tmp_path / "providers") + configured_database_module = DatabaseModule.for_root( + driver="sqlite", + database=database_name, + base=LocalBase, + create_all=False, + ) + + @Module(imports=[configured_database_module]) + class AppModule: + pass + + app = PyNestFactory.create(AppModule) + options = app.container.get(DATABASE_OPTIONS) + engine = app.container.get(DATABASE_ENGINE) + session_factory = app.container.get(DATABASE_SESSION_FACTORY) + database = app.container.get(DatabaseService) + + assert isinstance(options, DatabaseOptions) + assert options.driver == "sqlite" + assert options.config_params == {"db_name": database_name} + assert database.options is options + assert database.engine is engine + assert database.session_factory is session_factory + + asyncio.run(app.close()) + + +def test_database_module_rejects_duplicate_root_registration(tmp_path): + class LocalBase(DeclarativeBase): + pass + + @Module( + imports=[ + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "primary"), + base=LocalBase, + ), + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "secondary"), + base=LocalBase, + ), + ] + ) + class AppModule: + pass + + with pytest.raises(RuntimeError, match="DatabaseModule.for_root"): + PyNestFactory.create(AppModule) + + +def test_database_module_does_not_create_tables_by_default(tmp_path): + class LocalBase(DeclarativeBase): + pass + + class Author(LocalBase): + __tablename__ = "default_authors" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String, nullable=False) + + @Module( + imports=[ + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "default-create-all"), + base=LocalBase, + ) + ] + ) + class AppModule: + pass + + app = PyNestFactory.create(AppModule) + database = app.container.get(DatabaseService) + + assert "default_authors" not in inspect(database.engine).get_table_names() + + asyncio.run(app.close()) + + +def test_database_service_runs_sync_lifecycle_hooks(): + events = [] + + class Metadata: + def create_all(self, bind): + events.append(("create_all", bind)) + + class LocalBase: + metadata = Metadata() + + class Engine: + def dispose(self): + events.append("dispose") + + engine = Engine() + options = DatabaseOptions( + driver="sqlite", + config_params={"db_name": "lifecycle"}, + base=LocalBase, + create_all=True, + ) + service = DatabaseService(options, engine, session_factory=lambda: object()) + + service.on_module_init() + service.on_module_destroy() + + assert events == [("create_all", engine), "dispose"] + + +def test_database_service_session_rolls_back_and_closes_on_error(): + events = [] + + class Session: + def rollback(self): + events.append("rollback") + + def close(self): + events.append("close") + + options = DatabaseOptions( + driver="sqlite", + config_params={"db_name": "sessions"}, + create_all=False, + ) + service = DatabaseService(options, engine=object(), session_factory=Session) + + with pytest.raises(ValueError, match="boom"): + with service.session() as session: + assert isinstance(session, Session) + raise ValueError("boom") + + assert events == ["rollback", "close"] + + +def test_database_service_can_be_replaced_by_app_provider(tmp_path): + class LocalBase(DeclarativeBase): + pass + + class FakeDatabaseService: + def session(self): + return "fake-session" + + fake_database = FakeDatabaseService() + + @Injectable + class UsesDatabase: + def __init__(self, db: DatabaseService): + self.db = db + + @Module( + imports=[ + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "replace"), + base=LocalBase, + create_all=False, + ) + ], + providers=[ + {"provide": DatabaseService, "useValue": fake_database}, + UsesDatabase, + ], + ) + class AppModule: + pass + + app = PyNestFactory.create(AppModule) + + assert app.container.get(DatabaseService) is fake_database + assert app.container.get(UsesDatabase).db is fake_database + + asyncio.run(app.close()) + + +def test_database_module_powers_feature_module_through_http_e2e(tmp_path): + class LocalBase(DeclarativeBase): + pass + + class Author(LocalBase): + __tablename__ = "authors" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String, nullable=False) + + @Injectable + class AuthorService: + def __init__(self, db: DatabaseService): + self.db = db + + def create_and_list(self): + with self.db.session() as session: + session.add(Author(name="Le Guin")) + session.commit() + + with self.db.session() as session: + authors = session.query(Author).order_by(Author.name).all() + return [author.name for author in authors] + + @Controller("/authors", tag="authors") + class AuthorController: + def __init__(self, service: AuthorService): + self.service = service + + @Get("/") + def list_authors(self): + return {"authors": self.service.create_and_list()} + + @Module(controllers=[AuthorController], providers=[AuthorService]) + class AuthorModule: + pass + + @Module( + imports=[ + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "authors"), + base=LocalBase, + create_all=True, + ), + AuthorModule, + ] + ) + class AppModule: + pass + + app = PyNestFactory.create(AppModule) + database = app.container.get(DatabaseService) + + assert "authors" in inspect(database.engine).get_table_names() + + with TestClient(app.get_server()) as client: + response = client.get("/authors") + + assert response.status_code == 200 + assert response.json() == {"authors": ["Le Guin"]} + + +def test_async_database_module_creates_tables_and_runs_queries(tmp_path): + class LocalBase(DeclarativeBase): + pass + + class Author(LocalBase): + __tablename__ = "async_authors" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String, nullable=False) + + @Module( + imports=[ + DatabaseModule.for_root( + driver="sqlite", + database=str(tmp_path / "async-authors"), + base=LocalBase, + async_mode=True, + create_all=True, + ) + ] + ) + class AppModule: + pass + + app = PyNestFactory.create(AppModule) + database = app.container.get(DatabaseService) + + async def scenario(): + async with database.session() as session: + session.add(Author(name="Butler")) + await session.commit() + + async with database.session() as session: + result = await session.execute(select(Author.name)) + return result.scalars().all() + + assert asyncio.run(scenario()) == ["Butler"] + + asyncio.run(app.close()) From edac188693d72041bdd172fbdd3633a5a45dfa70 Mon Sep 17 00:00:00 2001 From: Itay Dar <118370953+ItayTheDar@users.noreply.github.com> Date: Mon, 1 Jun 2026 12:16:20 +0300 Subject: [PATCH 2/2] fix: align ORM templates with DatabaseModule --- nest/cli/templates/mysql_template.py | 48 ++++---- nest/cli/templates/orm_template.py | 86 ++++++++------- nest/cli/templates/postgres_template.py | 48 ++++---- nest/cli/templates/relational_db_template.py | 110 ++++++++----------- nest/cli/templates/sqlite_template.py | 38 ++++--- 5 files changed, 164 insertions(+), 166 deletions(-) diff --git a/nest/cli/templates/mysql_template.py b/nest/cli/templates/mysql_template.py index fc19afc..074f2b9 100644 --- a/nest/cli/templates/mysql_template.py +++ b/nest/cli/templates/mysql_template.py @@ -12,27 +12,27 @@ def __init__(self, module_name: str): ) def config_file(self): - return """from nest.core.database.orm_provider import OrmProvider -import os + return """import os from dotenv import load_dotenv load_dotenv() -config = OrmProvider( - db_type="mysql", - config_params=dict( - host=os.getenv("MYSQL_HOST"), - db_name=os.getenv("MYSQL_DB_NAME"), - user=os.getenv("MYSQL_USER"), - password=os.getenv("MYSQL_PASSWORD"), - port=int(os.getenv("MYSQL_PORT")), - ) +DATABASE_CONFIG = dict( + driver="mysql", + host=os.getenv("MYSQL_HOST"), + database=os.getenv("MYSQL_DB_NAME"), + user=os.getenv("MYSQL_USER"), + password=os.getenv("MYSQL_PASSWORD"), + port=int(os.getenv("MYSQL_PORT", 3306)), + create_all=True, ) """ def requirements_file(self): return f"""pynest-api +sqlalchemy>=2.0.36,<3.0.0 mysql-connector-python==8.2.0 +python-dotenv>=1.0.1,<2.0.0 """ @@ -44,25 +44,27 @@ def __init__(self, module_name: str): ) def config_file(self): - return """from nest.core.database.orm_provider import AsyncOrmProvider -import os + return """import os from dotenv import load_dotenv load_dotenv() -config = AsyncOrmProvider( - db_type="mysql", - config_params=dict( - host=os.getenv("MYSQL_HOST"), - db_name=os.getenv("MYSQL_DB_NAME"), - user=os.getenv("MYSQL_USER"), - password=os.getenv("MYSQL_PASSWORD"), - port=int(os.getenv("MYSQL_PORT")), - ) -) +DATABASE_CONFIG = { + "driver": "mysql", + "host": os.getenv("MYSQL_HOST"), + "database": os.getenv("MYSQL_DB_NAME"), + "user": os.getenv("MYSQL_USER"), + "password": os.getenv("MYSQL_PASSWORD"), + "port": int(os.getenv("MYSQL_PORT", 3306)), + "async_mode": True, + "create_all": True, +} """ def requirements_file(self): return f"""pynest-api +sqlalchemy>=2.0.36,<3.0.0 aiomysql==0.2.0 +greenlet>=3.1.1,<4.0.0 +python-dotenv>=1.0.1,<2.0.0 """ diff --git a/nest/cli/templates/orm_template.py b/nest/cli/templates/orm_template.py index c85ce29..13527de 100644 --- a/nest/cli/templates/orm_template.py +++ b/nest/cli/templates/orm_template.py @@ -14,12 +14,17 @@ def __init__(self, module_name: str, db_type: Database): def app_file(self): return f"""from nest.core import PyNestFactory, Module -from .config import config +from nest.core.database import DatabaseModule +from .config import DATABASE_CONFIG from .app_controller import AppController from .app_service import AppService -@Module(imports=[], controllers=[AppController], providers=[AppService]) +@Module( + imports=[DatabaseModule.for_root(**DATABASE_CONFIG)], + controllers=[AppController], + providers=[AppService], +) class AppModule: pass @@ -33,10 +38,6 @@ class AppModule: ) http_server = app.get_server() - -@http_server.on_event("startup") -def startup(): - config.create_all() """ @abstractmethod @@ -78,11 +79,11 @@ class {self.capitalized_module_name}(BaseModel): """ def entity_file(self): - return f"""from src.config import config + return f"""from nest.core.database import Base from sqlalchemy import Column, Integer, String, Float -class {self.capitalized_module_name}(config.Base): +class {self.capitalized_module_name}(Base): __tablename__ = "{self.module_name}" id = Column(Integer, primary_key=True, autoincrement=True) @@ -93,20 +94,20 @@ class {self.capitalized_module_name}(config.Base): def service_file(self): return f"""from .{self.module_name}_model import {self.capitalized_module_name} from .{self.module_name}_entity import {self.capitalized_module_name} as {self.capitalized_module_name}Entity -from src.config import config from nest.core.decorators.database import db_request_handler from nest.core import Injectable +from nest.core.database import DatabaseService @Injectable class {self.capitalized_module_name}Service: - def __init__(self): - self.config = config + def __init__(self, db: DatabaseService): + self.db = db @db_request_handler def add_{self.module_name}(self, {self.module_name}: {self.capitalized_module_name}): - with self.config.get_session() as session: + with self.db.session() as session: new_{self.module_name} = {self.capitalized_module_name}Entity( **{self.module_name}.dict() ) @@ -116,7 +117,7 @@ def add_{self.module_name}(self, {self.module_name}: {self.capitalized_module_na @db_request_handler def get_{self.module_name}(self): - with self.config.get_session() as session: + with self.db.session() as session: return session.query({self.capitalized_module_name}Entity).all() """ @@ -208,12 +209,17 @@ def generate_project(self, project_name: str): class AsyncORMTemplate(ORMTemplate, ABC): def app_file(self): return f"""from nest.core import PyNestFactory, Module -from .config import config +from nest.core.database import DatabaseModule +from .config import DATABASE_CONFIG from .app_controller import AppController from .app_service import AppService -@Module(imports=[], controllers=[AppController], providers=[AppService]) +@Module( + imports=[DatabaseModule.for_root(**DATABASE_CONFIG)], + controllers=[AppController], + providers=[AppService], +) class AppModule: pass @@ -227,10 +233,6 @@ class AppModule: ) http_server = app.get_server() - -@http_server.on_event("startup") -async def startup(): - await config.create_all() """ @@ -243,12 +245,12 @@ def requirements_file(self): pass def entity_file(self): - return f"""from src.config import config + return f"""from nest.core.database import Base from sqlalchemy import Integer, String from sqlalchemy.orm import Mapped, mapped_column -class {self.capitalized_module_name}(config.Base): +class {self.capitalized_module_name}(Base): __tablename__ = "{self.module_name}" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) @@ -261,34 +263,36 @@ def service_file(self): from .{self.module_name}_entity import {self.capitalized_module_name} as {self.capitalized_module_name}Entity from nest.core.decorators.database import async_db_request_handler from nest.core import Injectable +from nest.core.database import DatabaseService from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession @Injectable class {self.capitalized_module_name}Service: + def __init__(self, db: DatabaseService): + self.db = db + @async_db_request_handler - async def add_{self.module_name}(self, {self.module_name}: {self.capitalized_module_name}, session: AsyncSession): - new_{self.module_name} = {self.capitalized_module_name}Entity( - **{self.module_name}.dict() - ) - session.add(new_{self.module_name}) - await session.commit() - return new_{self.module_name}.id + async def add_{self.module_name}(self, {self.module_name}: {self.capitalized_module_name}): + async with self.db.session() as session: + new_{self.module_name} = {self.capitalized_module_name}Entity( + **{self.module_name}.dict() + ) + session.add(new_{self.module_name}) + await session.commit() + return new_{self.module_name}.id @async_db_request_handler - async def get_{self.module_name}(self, session: AsyncSession): - query = select({self.capitalized_module_name}Entity) - result = await session.execute(query) - return result.scalars().all() + async def get_{self.module_name}(self): + async with self.db.session() as session: + query = select({self.capitalized_module_name}Entity) + result = await session.execute(query) + return result.scalars().all() """ def controller_file(self): - return f"""from nest.core import Controller, Get, Post, Depends -from sqlalchemy.ext.asyncio import AsyncSession -from src.config import config - + return f"""from nest.core import Controller, Get, Post from .{self.module_name}_service import {self.capitalized_module_name}Service from .{self.module_name}_model import {self.capitalized_module_name} @@ -301,12 +305,12 @@ def __init__(self, {self.module_name}_service: {self.capitalized_module_name}Ser self.{self.module_name}_service = {self.module_name}_service @Get("/") - async def get_{self.module_name}(self, session: AsyncSession = Depends(config.get_db)): - return await self.{self.module_name}_service.get_{self.module_name}(session) + async def get_{self.module_name}(self): + return await self.{self.module_name}_service.get_{self.module_name}() @Post("/") - async def add_{self.module_name}(self, {self.module_name}: {self.capitalized_module_name}, session: AsyncSession = Depends(config.get_db)): - return await self.{self.module_name}_service.add_{self.module_name}({self.module_name}, session) + async def add_{self.module_name}(self, {self.module_name}: {self.capitalized_module_name}): + return await self.{self.module_name}_service.add_{self.module_name}({self.module_name}) """ def settings_file(self): diff --git a/nest/cli/templates/postgres_template.py b/nest/cli/templates/postgres_template.py index 191cc95..98d142b 100644 --- a/nest/cli/templates/postgres_template.py +++ b/nest/cli/templates/postgres_template.py @@ -12,27 +12,27 @@ def __init__(self, module_name: str): ) def config_file(self): - return """from nest.core.database.orm_provider import OrmProvider -import os + return """import os from dotenv import load_dotenv load_dotenv() -config = OrmProvider( - db_type="postgresql", - config_params=dict( - host=os.getenv("POSTGRESQL_HOST", "localhost"), - db_name=os.getenv("POSTGRESQL_DB_NAME", "default_nest_db"), - user=os.getenv("POSTGRESQL_USER", "postgres"), - password=os.getenv("POSTGRESQL_PASSWORD", "postgres"), - port=int(os.getenv("POSTGRESQL_PORT", 5432)), - ) +DATABASE_CONFIG = dict( + driver="postgresql", + host=os.getenv("POSTGRESQL_HOST", "localhost"), + database=os.getenv("POSTGRESQL_DB_NAME", "default_nest_db"), + user=os.getenv("POSTGRESQL_USER", "postgres"), + password=os.getenv("POSTGRESQL_PASSWORD", "postgres"), + port=int(os.getenv("POSTGRESQL_PORT", 5432)), + create_all=True, ) """ def requirements_file(self): return f"""pynest-api +sqlalchemy>=2.0.36,<3.0.0 psycopg2==2.9.6 +python-dotenv>=1.0.1,<2.0.0 """ @@ -44,25 +44,27 @@ def __init__(self, module_name: str): ) def config_file(self): - return """from nest.core.database.orm_provider import AsyncOrmProvider -import os + return """import os from dotenv import load_dotenv load_dotenv() -config = AsyncOrmProvider( - db_type="postgresql", - config_params=dict( - host=os.getenv("POSTGRESQL_HOST", "localhost"), - db_name=os.getenv("POSTGRESQL_DB_NAME", "default_nest_db"), - user=os.getenv("POSTGRESQL_USER", "postgres"), - password=os.getenv("POSTGRESQL_PASSWORD", "postgres"), - port=int(os.getenv("POSTGRESQL_PORT", 5432)), - ) -) +DATABASE_CONFIG = { + "driver": "postgresql", + "host": os.getenv("POSTGRESQL_HOST", "localhost"), + "database": os.getenv("POSTGRESQL_DB_NAME", "default_nest_db"), + "user": os.getenv("POSTGRESQL_USER", "postgres"), + "password": os.getenv("POSTGRESQL_PASSWORD", "postgres"), + "port": int(os.getenv("POSTGRESQL_PORT", 5432)), + "async_mode": True, + "create_all": True, +} """ def requirements_file(self): return f"""pynest-api +sqlalchemy>=2.0.36,<3.0.0 asyncpg==0.29.0 +greenlet>=3.1.1,<4.0.0 +python-dotenv>=1.0.1,<2.0.0 """ diff --git a/nest/cli/templates/relational_db_template.py b/nest/cli/templates/relational_db_template.py index f915076..1e616a0 100644 --- a/nest/cli/templates/relational_db_template.py +++ b/nest/cli/templates/relational_db_template.py @@ -11,52 +11,55 @@ def __init__(self, name, db_type): def generate_service_file(self) -> str: return f"""from src.{self.name}.{self.name}_model import {self.capitalized_name} from src.{self.name}.{self.name}_entity import {self.capitalized_name} as {self.capitalized_name}Entity -from orm_config import config -from nest.core.decorators import db_request_handler -from functools import lru_cache +from nest.core import Injectable +from nest.core.database import DatabaseService +from nest.core.decorators.database import db_request_handler -@lru_cache() +@Injectable class {self.capitalized_name}Service: - def __init__(self): - self.orm_config = config - self.session = self.orm_config.get_db() + def __init__(self, db: DatabaseService): + self.db = db @db_request_handler def add_{self.name}(self, {self.name}: {self.capitalized_name}): - new_{self.name} = {self.capitalized_name}Entity( - **{self.name}.dict() - ) - self.session.add(new_{self.name}) - self.session.commit() - return new_{self.name}.id + with self.db.session() as session: + new_{self.name} = {self.capitalized_name}Entity( + **{self.name}.dict() + ) + session.add(new_{self.name}) + session.commit() + return new_{self.name}.id @db_request_handler def get_{self.name}(self): - return self.session.query({self.capitalized_name}Entity).all() + with self.db.session() as session: + return session.query({self.capitalized_name}Entity).all() @db_request_handler def delete_{self.name}(self, {self.name}_id: int): - self.session.query({self.capitalized_name}Entity).filter_by(id={self.name}_id).delete() - self.session.commit() - return {self.name}_id + with self.db.session() as session: + session.query({self.capitalized_name}Entity).filter_by(id={self.name}_id).delete() + session.commit() + return {self.name}_id @db_request_handler def update_{self.name}(self, {self.name}_id: int, {self.name}: {self.capitalized_name}): - self.session.query({self.capitalized_name}Entity).filter_by(id={self.name}_id).update( - {self.name}.dict() - ) - self.session.commit() - return {self.name}_id + with self.db.session() as session: + session.query({self.capitalized_name}Entity).filter_by(id={self.name}_id).update( + {self.name}.dict() + ) + session.commit() + return {self.name}_id """ def generate_entity_file(self) -> str: - return f"""from orm_config import config + return f"""from nest.core.database import Base from sqlalchemy import Column, Integer, String, Float -class {self.capitalized_name}(config.Base): +class {self.capitalized_name}(Base): __tablename__ = "{self.name}" id = Column(Integer, primary_key=True, autoincrement=True) @@ -64,58 +67,41 @@ class {self.capitalized_name}(config.Base): """ def generate_requirements_file(self) -> str: - return f"""anyio==3.6.2 -click==8.1.3 -fastapi==0.95.1 -fastapi-utils==0.2.1 -greenlet==2.0.2 -h11==0.14.0 -idna==3.4 -pydantic==1.10.7 -python-dotenv==1.0.0 -sniffio==1.3.0 -SQLAlchemy==1.4.48 -starlette==0.26.1 -typing_extensions==4.5.0 -uvicorn==0.22.0 -pynest-api=={version} + return f"""pynest-api=={version} +sqlalchemy>=2.0.36,<3.0.0 +python-dotenv>=1.0.1,<2.0.0 """ def generate_dockerfile(self) -> str: pass def generate_orm_config_file(self) -> str: - base_template = f"""from nest.core.database.base_orm import OrmService -import os + base_template = f"""import os from dotenv import load_dotenv load_dotenv() - - """ if self.db_type == "sqlite": return f"""{base_template} - config = OrmService( - db_type="{self.db_type}", - config_params=dict( - db_name=os.getenv("SQLITE_DB_NAME", "{self.name}_db"), - ) - ) +DATABASE_CONFIG = dict( + driver="{self.db_type}", + database=os.getenv("SQLITE_DB_NAME", "{self.name}_db"), + create_all=True, +) + """ + default_port = 5432 if self.db_type == "postgresql" else 3306 + return f"""{base_template} +DATABASE_CONFIG = dict( + driver="{self.db_type}", + host=os.getenv("{self.db_type.upper()}_HOST", "localhost"), + database=os.getenv("{self.db_type.upper()}_DB_NAME", "{self.name}_db"), + user=os.getenv("{self.db_type.upper()}_USER"), + password=os.getenv("{self.db_type.upper()}_PASSWORD"), + port=int(os.getenv("{self.db_type.upper()}_PORT", {default_port})), + create_all=True, +) """ - else: - return f"""{base_template} - config = OrmService( - db_type="{self.db_type}", - config_params=dict( - host=os.getenv("{self.db_type.upper()}_HOST"), - db_name=os.getenv("{self.db_type.upper()}_DB_NAME"), - user=os.getenv("{self.db_type.upper()}_USER"), - password=os.getenv("{self.db_type.upper()}_PASSWORD"), - port=int(os.getenv("{self.db_type.upper()}_PORT")), - ) - ) - """ if __name__ == "__main__": diff --git a/nest/cli/templates/sqlite_template.py b/nest/cli/templates/sqlite_template.py index 461a7d9..8092284 100644 --- a/nest/cli/templates/sqlite_template.py +++ b/nest/cli/templates/sqlite_template.py @@ -12,22 +12,23 @@ def __init__(self, module_name: str): ) def config_file(self): - return """from nest.core.database.orm_provider import OrmProvider -import os + return """import os from dotenv import load_dotenv load_dotenv() -config = OrmProvider( - db_type="sqlite", - config_params=dict( - db_name=os.getenv("SQLITE_DB_NAME", "default_nest_db"), - ) +DATABASE_CONFIG = dict( + driver="sqlite", + database=os.getenv("SQLITE_DB_NAME", "default_nest_db"), + create_all=True, ) """ def requirements_file(self): - return f"""pynest-api""" + return f"""pynest-api +sqlalchemy>=2.0.36,<3.0.0 +python-dotenv>=1.0.1,<2.0.0 +""" def docker_file(self): return """FROM tiangolo/uvicorn-gunicorn-fastapi:python3.11 @@ -45,23 +46,26 @@ def __init__(self, module_name: str): ) def config_file(self): - return """from nest.core.database.orm_provider import AsyncOrmProvider -import os + return """import os from dotenv import load_dotenv load_dotenv() -config = AsyncOrmProvider( - db_type="sqlite", - config_params=dict( - db_name=os.getenv("SQLITE_DB_NAME", "default_nest_db"), - ) -) +DATABASE_CONFIG = { + "driver": "sqlite", + "database": os.getenv("SQLITE_DB_NAME", "default_nest_db"), + "async_mode": True, + "create_all": True, +} """ def requirements_file(self): return f"""pynest-api -aiosqlite==0.19.0""" +sqlalchemy>=2.0.36,<3.0.0 +aiosqlite==0.19.0 +greenlet>=3.1.1,<4.0.0 +python-dotenv>=1.0.1,<2.0.0 +""" def docker_file(self): return """FROM tiangolo/uvicorn-gunicorn-fastapi:python3.11