99
1010from sqlalchemy_database ._abc_async_database import AbcAsyncDatabase
1111
12- _T = TypeVar ("_T" )
1312_P = ParamSpec ("_P" )
13+ _T = TypeVar ("_T" )
14+ _R = TypeVar ("_R" )
15+
1416_ExecuteParams = Union [Mapping [Any , Any ], Sequence [Mapping [Any , Any ]]]
1517_ExecuteOptions = Mapping [Any , Any ]
1618
@@ -38,6 +40,7 @@ async def execute(
3840 self ,
3941 statement : Executable ,
4042 params : Optional [_ExecuteParams ] = None ,
43+ * ,
4144 execution_options : Optional [_ExecuteOptions ] = None ,
4245 bind_arguments : Optional [Mapping [str , Any ]] = None ,
4346 commit : bool = False ,
@@ -56,6 +59,7 @@ async def scalar(
5659 self ,
5760 statement : Executable ,
5861 params : Optional [_ExecuteParams ] = None ,
62+ * ,
5963 execution_options : Optional [_ExecuteOptions ] = None ,
6064 bind_arguments : Optional [Mapping [str , Any ]] = None ,
6165 ** kw : Any ,
@@ -73,6 +77,7 @@ async def scalars_all(
7377 self ,
7478 statement : Executable ,
7579 params : Optional [_ExecuteParams ] = None ,
80+ * ,
7681 execution_options : Optional [_ExecuteOptions ] = None ,
7782 ** kw : Any ,
7883 ) -> List [Any ]:
@@ -89,6 +94,7 @@ async def get(
8994 self ,
9095 entity : Type [_T ],
9196 ident : Any ,
97+ * ,
9298 options : Optional [Sequence [Any ]] = None ,
9399 populate_existing : bool = False ,
94100 with_for_update : Optional [Any ] = None ,
@@ -114,15 +120,19 @@ async def run_sync(
114120 self ,
115121 fn : Callable [[Concatenate [Union [Session , Connection ], _P ]], _T ],
116122 * args : _P .args ,
123+ commit : bool = True ,
124+ on_close_pre : Callable [[_T ], _R ] = None ,
117125 is_session : bool = False ,
118126 ** kwargs : _P .kwargs
119- ) -> _T :
120- if is_session :
121- async with self .session_maker () as session :
122- async with session .begin ():
123- return await session .run_sync (fn , * args , ** kwargs )
124- async with self .engine .begin () as conn :
125- return await conn .run_sync (fn , * args , ** kwargs )
127+ ) -> Union [_T ,_R ]:
128+ maker = self .session_maker if is_session else self .engine .connect
129+ async with maker () as conn :
130+ result = await conn .run_sync (fn , * args , ** kwargs )
131+ if commit :
132+ await conn .commit ()
133+ if on_close_pre :
134+ result = on_close_pre (result )
135+ return result
126136
127137
128138class Database (AbcAsyncDatabase ):
@@ -147,6 +157,7 @@ def execute(
147157 self ,
148158 statement : Executable ,
149159 params : Optional [_ExecuteParams ] = None ,
160+ * ,
150161 execution_options : Optional [_ExecuteOptions ] = None ,
151162 bind_arguments : Optional [Mapping [str , Any ]] = None ,
152163 commit : bool = False ,
@@ -165,6 +176,7 @@ def scalar(
165176 self ,
166177 statement : Executable ,
167178 params : Optional [_ExecuteParams ] = None ,
179+ * ,
168180 execution_options : Optional [_ExecuteOptions ] = None ,
169181 bind_arguments : Optional [Mapping [str , Any ]] = None ,
170182 ** kw : Any ,
@@ -182,6 +194,7 @@ def scalars_all(
182194 self ,
183195 statement : Executable ,
184196 params : Optional [_ExecuteParams ] = None ,
197+ * ,
185198 execution_options : Optional [_ExecuteOptions ] = None ,
186199 bind_arguments : Optional [Mapping [str , Any ]] = None ,
187200 ** kw : Any ,
@@ -199,6 +212,7 @@ def get(
199212 self ,
200213 entity : Type [_T ],
201214 ident : Any ,
215+ * ,
202216 options : Optional [Sequence [Any ]] = None ,
203217 populate_existing : bool = False ,
204218 with_for_update : Optional [Any ] = None ,
@@ -224,12 +238,16 @@ def run_sync(
224238 self ,
225239 fn : Callable [[Concatenate [Union [Session , Connection ], _P ]], _T ],
226240 * args : _P .args ,
241+ commit : bool = True ,
242+ on_close_pre : Callable [[_T ], _R ] = None ,
227243 is_session : bool = False ,
228244 ** kwargs : _P .kwargs
229- ) -> _T :
230- if is_session :
231- with self .session_maker () as session :
232- with session .begin ():
233- return fn (session , * args , ** kwargs )
234- with self .engine .begin () as conn :
235- return fn (conn , * args , ** kwargs )
245+ ) -> Union [_T , _R ]:
246+ maker = self .session_maker if is_session else self .engine .connect
247+ with maker () as conn :
248+ result = fn (conn , * args , ** kwargs )
249+ if commit :
250+ conn .commit ()
251+ if on_close_pre :
252+ result = on_close_pre (result )
253+ return result
0 commit comments