Skip to content

Commit 04b2f32

Browse files
committed
feat: Support automatic commit session before closed.
1 parent f330796 commit 04b2f32

2 files changed

Lines changed: 31 additions & 8 deletions

File tree

sqlalchemy_database/_abc_async_database.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ async def asgi_dispatch(self, request, call_next):
4040
app.add_middleware(BaseHTTPMiddleware, dispatch=db.asgi_dispatch)
4141
```
4242
"""
43-
async with self.__call__():
44-
response = await call_next(request)
45-
return response
43+
if self.session is None: # bind session to request
44+
async with self.__call__():
45+
return await call_next(request)
46+
return await call_next(request)

sqlalchemy_database/database.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,17 @@
4040
class AsyncDatabase(AbcAsyncDatabase):
4141
"""`sqlalchemy` asynchronous database client"""
4242

43-
def __init__(self, engine: AsyncEngine, **session_options):
43+
def __init__(
44+
self,
45+
engine: AsyncEngine,
46+
commit_on_close: bool = False,
47+
**session_options,
48+
):
4449
"""
4550
Initialize the client through the asynchronous engine
4651
Args:
4752
engine: Asynchronous Engine
53+
commit_on_close: Whether to commit the session when the context manager or session generator exits.
4854
**session_options: The default `session` initialization parameters
4955
"""
5056
super().__init__()
@@ -57,6 +63,8 @@ def __init__(self, engine: AsyncEngine, **session_options):
5763
await conn.run_sync(SQLModel.metadata.create_all)
5864
```
5965
"""
66+
self.commit_on_close: bool = commit_on_close
67+
"""Whether to commit the session when the context manager or session generator exits."""
6068
session_options.setdefault("class_", AsyncSession)
6169
self.session_maker: Callable[..., AsyncSession] = sessionmaker(self.engine, **session_options)
6270
"""`sqlalchemy` session factory function
@@ -97,11 +105,14 @@ def __call__(self):
97105
return AsyncSessionContextVarManager(self)
98106

99107
@classmethod
100-
def create(cls, url: str, *, session_options: Mapping[str, Any] = None, **kwargs) -> "AsyncDatabase":
108+
def create(
109+
cls, url: str, *, commit_on_close: bool = False, session_options: Mapping[str, Any] = None, **kwargs
110+
) -> "AsyncDatabase":
101111
"""
102112
Initialize the client with a database connection string
103113
Args:
104114
url: Asynchronous database connection string
115+
commit_on_close: Whether to commit the session when the context manager or session generator exits.
105116
session_options: The default `session` initialization parameters
106117
**kwargs: Asynchronous engine initialization parameters
107118
@@ -111,7 +122,7 @@ def create(cls, url: str, *, session_options: Mapping[str, Any] = None, **kwargs
111122
kwargs.setdefault("future", True)
112123
engine = create_async_engine(url, **kwargs)
113124
session_options = session_options or {}
114-
return cls(engine, **session_options)
125+
return cls(engine, commit_on_close=commit_on_close, **session_options)
115126

116127
async def session_generator(self) -> AsyncGenerator[AsyncSession, Any]:
117128
"""AsyncSession Generator, available for FastAPI dependencies.
@@ -125,6 +136,8 @@ async def session_generator(self) -> AsyncGenerator[AsyncSession, Any]:
125136
"""
126137
async with self.session_maker() as session:
127138
yield session
139+
if self.commit_on_close:
140+
await session.commit()
128141

129142
async def _executor_maker(
130143
self, executor: Union[AsyncSession, AsyncConnection, None] = None, is_session: bool = True
@@ -421,9 +434,10 @@ async def refresh(self, instance, attribute_names=None, with_for_update=None, se
421434
class Database(AbcAsyncDatabase):
422435
"""`sqlalchemy` synchronous database client"""
423436

424-
def __init__(self, engine: Engine, **session_options):
437+
def __init__(self, engine: Engine, commit_on_close: bool = False, **session_options):
425438
super().__init__()
426439
self.engine: Engine = engine
440+
self.commit_on_close: bool = commit_on_close
427441
session_options.setdefault("class_", Session)
428442
self.session_maker: Callable[..., Session] = sessionmaker(self.engine, **session_options)
429443
self._session_context_var: ContextVar[Optional[Session]] = ContextVar("_session_context_var", default=None)
@@ -437,7 +451,9 @@ def __call__(self):
437451
return SessionContextVarManager(self)
438452

439453
@classmethod
440-
def create(cls, url: str, *, session_options: Optional[Mapping[str, Any]] = None, **kwargs) -> "Database":
454+
def create(
455+
cls, url: str, *, commit_on_close: bool = False, session_options: Optional[Mapping[str, Any]] = None, **kwargs
456+
) -> "Database":
441457
kwargs.setdefault("future", True)
442458
engine = create_engine(url, **kwargs)
443459
session_options = session_options or {}
@@ -446,6 +462,8 @@ def create(cls, url: str, *, session_options: Optional[Mapping[str, Any]] = None
446462
def session_generator(self) -> Generator[Session, Any, None]:
447463
with self.session_maker() as session:
448464
yield session
465+
if self.commit_on_close:
466+
session.commit()
449467

450468
def _executor_maker(
451469
self, executor: Union[Session, Connection, None] = None, is_session: bool = True
@@ -626,6 +644,8 @@ async def __aexit__(self, exc_type, exc_value, traceback):
626644
session = self.db._session_context_var.get()
627645
if exc_type is not None:
628646
await session.rollback()
647+
if self.db.commit_on_close:
648+
await session.commit()
629649
await session.close()
630650
self.db._session_context_var.reset(self.token)
631651

@@ -644,6 +664,8 @@ def __exit__(self, exc_type, exc_value, traceback):
644664
session = self.db._session_context_var.get()
645665
if exc_type is not None:
646666
session.rollback()
667+
if self.db.commit_on_close:
668+
session.commit()
647669
session.close()
648670
self.db._session_context_var.reset(self.token)
649671

0 commit comments

Comments
 (0)