4040class AsyncDatabase (AbcAsyncDatabase ):
4141 """`sqlalchemy` asynchronous database client"""
4242
43- def __init__ (self , engine : AsyncEngine , ** session_options ):
43+ def __init__ (
44+ self ,
45+ engine : AsyncEngine ,
46+ commit_on_close : bool = False ,
47+ ** session_options ,
48+ ):
4449 """
4550 Initialize the client through the asynchronous engine
4651 Args:
4752 engine: Asynchronous Engine
53+ commit_on_close: Whether to commit the session when the context manager or session generator exits.
4854 **session_options: The default `session` initialization parameters
4955 """
5056 super ().__init__ ()
@@ -57,6 +63,8 @@ def __init__(self, engine: AsyncEngine, **session_options):
5763 await conn.run_sync(SQLModel.metadata.create_all)
5864 ```
5965 """
66+ self .commit_on_close : bool = commit_on_close
67+ """Whether to commit the session when the context manager or session generator exits."""
6068 session_options .setdefault ("class_" , AsyncSession )
6169 self .session_maker : Callable [..., AsyncSession ] = sessionmaker (self .engine , ** session_options )
6270 """`sqlalchemy` session factory function
@@ -97,11 +105,14 @@ def __call__(self):
97105 return AsyncSessionContextVarManager (self )
98106
99107 @classmethod
100- def create (cls , url : str , * , session_options : Mapping [str , Any ] = None , ** kwargs ) -> "AsyncDatabase" :
108+ def create (
109+ cls , url : str , * , commit_on_close : bool = False , session_options : Mapping [str , Any ] = None , ** kwargs
110+ ) -> "AsyncDatabase" :
101111 """
102112 Initialize the client with a database connection string
103113 Args:
104114 url: Asynchronous database connection string
115+ commit_on_close: Whether to commit the session when the context manager or session generator exits.
105116 session_options: The default `session` initialization parameters
106117 **kwargs: Asynchronous engine initialization parameters
107118
@@ -111,7 +122,7 @@ def create(cls, url: str, *, session_options: Mapping[str, Any] = None, **kwargs
111122 kwargs .setdefault ("future" , True )
112123 engine = create_async_engine (url , ** kwargs )
113124 session_options = session_options or {}
114- return cls (engine , ** session_options )
125+ return cls (engine , commit_on_close = commit_on_close , ** session_options )
115126
116127 async def session_generator (self ) -> AsyncGenerator [AsyncSession , Any ]:
117128 """AsyncSession Generator, available for FastAPI dependencies.
@@ -125,6 +136,8 @@ async def session_generator(self) -> AsyncGenerator[AsyncSession, Any]:
125136 """
126137 async with self .session_maker () as session :
127138 yield session
139+ if self .commit_on_close :
140+ await session .commit ()
128141
129142 async def _executor_maker (
130143 self , executor : Union [AsyncSession , AsyncConnection , None ] = None , is_session : bool = True
@@ -421,9 +434,10 @@ async def refresh(self, instance, attribute_names=None, with_for_update=None, se
421434class Database (AbcAsyncDatabase ):
422435 """`sqlalchemy` synchronous database client"""
423436
424- def __init__ (self , engine : Engine , ** session_options ):
437+ def __init__ (self , engine : Engine , commit_on_close : bool = False , ** session_options ):
425438 super ().__init__ ()
426439 self .engine : Engine = engine
440+ self .commit_on_close : bool = commit_on_close
427441 session_options .setdefault ("class_" , Session )
428442 self .session_maker : Callable [..., Session ] = sessionmaker (self .engine , ** session_options )
429443 self ._session_context_var : ContextVar [Optional [Session ]] = ContextVar ("_session_context_var" , default = None )
@@ -437,7 +451,9 @@ def __call__(self):
437451 return SessionContextVarManager (self )
438452
439453 @classmethod
440- def create (cls , url : str , * , session_options : Optional [Mapping [str , Any ]] = None , ** kwargs ) -> "Database" :
454+ def create (
455+ cls , url : str , * , commit_on_close : bool = False , session_options : Optional [Mapping [str , Any ]] = None , ** kwargs
456+ ) -> "Database" :
441457 kwargs .setdefault ("future" , True )
442458 engine = create_engine (url , ** kwargs )
443459 session_options = session_options or {}
@@ -446,6 +462,8 @@ def create(cls, url: str, *, session_options: Optional[Mapping[str, Any]] = None
446462 def session_generator (self ) -> Generator [Session , Any , None ]:
447463 with self .session_maker () as session :
448464 yield session
465+ if self .commit_on_close :
466+ session .commit ()
449467
450468 def _executor_maker (
451469 self , executor : Union [Session , Connection , None ] = None , is_session : bool = True
@@ -626,6 +644,8 @@ async def __aexit__(self, exc_type, exc_value, traceback):
626644 session = self .db ._session_context_var .get ()
627645 if exc_type is not None :
628646 await session .rollback ()
647+ if self .db .commit_on_close :
648+ await session .commit ()
629649 await session .close ()
630650 self .db ._session_context_var .reset (self .token )
631651
@@ -644,6 +664,8 @@ def __exit__(self, exc_type, exc_value, traceback):
644664 session = self .db ._session_context_var .get ()
645665 if exc_type is not None :
646666 session .rollback ()
667+ if self .db .commit_on_close :
668+ session .commit ()
647669 session .close ()
648670 self .db ._session_context_var .reset (self .token )
649671
0 commit comments