@@ -143,14 +143,13 @@ async def execute(
143143 kw ['bind_arguments' ] = bind_arguments
144144 else :
145145 executor = await self .engine .connect ()
146- result = await executor .execute (statement , params , execution_options , ** kw ) # type:ignore
147- if on_close_pre :
148- result = on_close_pre (result )
149- if commit and not isinstance (statement , Select ):
150- await executor .commit ()
151- if need_close :
152- await executor .close ()
153- return result
146+ async with ExecutorContextManager (executor , need_close = need_close ) as executor :
147+ result = await executor .execute (statement , params , execution_options , ** kw ) # type:ignore
148+ if on_close_pre :
149+ result = on_close_pre (result )
150+ if commit and not isinstance (statement , Select ):
151+ await executor .commit ()
152+ return result
154153
155154 async def scalar (
156155 self ,
@@ -172,16 +171,15 @@ async def scalar(
172171 executor = self .session_maker ()
173172 else :
174173 executor = session
175- result = await executor .scalar (
176- statement ,
177- params ,
178- execution_options = execution_options ,
179- bind_arguments = bind_arguments ,
180- ** kw ,
181- )
182- if session is None :
183- await executor .close ()
184- return result
174+ async with ExecutorContextManager (executor , need_close = session is None ) as executor :
175+ result = await executor .scalar (
176+ statement ,
177+ params ,
178+ execution_options = execution_options ,
179+ bind_arguments = bind_arguments ,
180+ ** kw ,
181+ )
182+ return result
185183
186184 async def scalars_all (
187185 self ,
@@ -202,15 +200,14 @@ async def scalars_all(
202200 executor = self .session_maker ()
203201 else :
204202 executor = session
205- result = (await executor .scalars (
206- statement ,
207- params ,
208- execution_options = execution_options ,
209- ** kw ,
210- )).all ()
211- if session is None :
212- await executor .close ()
213- return result
203+ async with ExecutorContextManager (executor , need_close = session is None ) as executor :
204+ result = (await executor .scalars (
205+ statement ,
206+ params ,
207+ execution_options = execution_options ,
208+ ** kw ,
209+ )).all ()
210+ return result
214211
215212 async def get (
216213 self ,
@@ -268,17 +265,16 @@ async def get(
268265 executor = self .session_maker ()
269266 else :
270267 executor = session
271- result = await executor .get (
272- entity ,
273- ident ,
274- options = options ,
275- populate_existing = populate_existing ,
276- with_for_update = with_for_update ,
277- identity_token = identity_token ,
278- )
279- if session is None :
280- await executor .close ()
281- return result
268+ async with ExecutorContextManager (executor , need_close = session is None ) as executor :
269+ result = await executor .get (
270+ entity ,
271+ ident ,
272+ options = options ,
273+ populate_existing = populate_existing ,
274+ with_for_update = with_for_update ,
275+ identity_token = identity_token ,
276+ )
277+ return result
282278
283279 async def delete (self , instance : Any ) -> None :
284280 """Deletes an instance object."""
@@ -303,12 +299,11 @@ async def save(
303299 executor = self .session_maker ()
304300 else :
305301 executor = session
306- executor .add_all (instances )
307- await executor .commit ()
308- if refresh :
309- [await executor .refresh (instance ) for instance in instances ]
310- if session is None :
311- await executor .close ()
302+ async with ExecutorContextManager (executor , need_close = session is None ) as executor :
303+ executor .add_all (instances )
304+ await executor .commit ()
305+ if refresh :
306+ [await executor .refresh (instance ) for instance in instances ]
312307
313308 async def run_sync (
314309 self ,
@@ -357,14 +352,13 @@ def get_user(session:Session,id:int):
357352 if executor is None or not isinstance (executor , (AsyncSession , AsyncConnection )):
358353 need_close = True
359354 executor = self .session_maker () if is_session else await self .engine .connect ()
360- result = await executor .run_sync (fn , * args , ** kwargs )
361- if on_close_pre :
362- result = on_close_pre (result )
363- if commit :
364- await executor .commit ()
365- if need_close :
366- await executor .close ()
367- return result
355+ async with ExecutorContextManager (executor , need_close = need_close ) as executor :
356+ result = await executor .run_sync (fn , * args , ** kwargs )
357+ if on_close_pre :
358+ result = on_close_pre (result )
359+ if commit :
360+ await executor .commit ()
361+ return result
368362
369363class Database (AbcAsyncDatabase ):
370364 """`sqlalchemy` synchronous database client
@@ -408,14 +402,13 @@ def execute(
408402 kw ['bind_arguments' ] = bind_arguments
409403 else :
410404 executor = self .engine .connect ()
411- result = executor .execute (statement , params , execution_options , ** kw )
412- if on_close_pre :
413- result = on_close_pre (result )
414- if commit and not isinstance (statement , Select ):
415- executor .commit ()
416- if need_close :
417- executor .close ()
418- return result
405+ with ExecutorContextManager (executor , need_close = need_close ) as executor :
406+ result = executor .execute (statement , params , execution_options , ** kw )
407+ if on_close_pre :
408+ result = on_close_pre (result )
409+ if commit and not isinstance (statement , Select ):
410+ executor .commit ()
411+ return result
419412
420413 def scalar (
421414 self ,
@@ -431,16 +424,15 @@ def scalar(
431424 executor = self .session_maker ()
432425 else :
433426 executor = session
434- result = executor .scalar (
435- statement ,
436- params ,
437- execution_options = execution_options ,
438- bind_arguments = bind_arguments ,
439- ** kw ,
440- )
441- if session is None :
442- executor .close ()
443- return result
427+ with ExecutorContextManager (executor , need_close = session is None ) as executor :
428+ result = executor .scalar (
429+ statement ,
430+ params ,
431+ execution_options = execution_options ,
432+ bind_arguments = bind_arguments ,
433+ ** kw ,
434+ )
435+ return result
444436
445437 def scalars_all (
446438 self ,
@@ -456,16 +448,15 @@ def scalars_all(
456448 executor = self .session_maker ()
457449 else :
458450 executor = session
459- result = executor .scalars (
460- statement ,
461- params ,
462- execution_options = execution_options ,
463- bind_arguments = bind_arguments ,
464- ** kw ,
465- ).all ()
466- if session is None :
467- executor .close ()
468- return result
451+ with ExecutorContextManager (executor , need_close = session is None ) as executor :
452+ result = executor .scalars (
453+ statement ,
454+ params ,
455+ execution_options = execution_options ,
456+ bind_arguments = bind_arguments ,
457+ ** kw ,
458+ ).all ()
459+ return result
469460
470461 def get (
471462 self ,
@@ -483,17 +474,16 @@ def get(
483474 executor = self .session_maker ()
484475 else :
485476 executor = session
486- result = executor .get (
487- entity ,
488- ident ,
489- options = options ,
490- populate_existing = populate_existing ,
491- with_for_update = with_for_update ,
492- identity_token = identity_token ,
493- )
494- if session is None :
495- executor .close ()
496- return result
477+ with ExecutorContextManager (executor , need_close = session is None ) as executor :
478+ result = executor .get (
479+ entity ,
480+ ident ,
481+ options = options ,
482+ populate_existing = populate_existing ,
483+ with_for_update = with_for_update ,
484+ identity_token = identity_token ,
485+ )
486+ return result
497487
498488 def delete (self , instance : Any ) -> None :
499489 with self .session_maker () as session :
@@ -510,12 +500,11 @@ def save(
510500 executor = self .session_maker ()
511501 else :
512502 executor = session
513- executor .add_all (instances )
514- executor .commit ()
515- if refresh :
516- [executor .refresh (instance ) for instance in instances ]
517- if session is None :
518- executor .close ()
503+ with ExecutorContextManager (executor , need_close = session is None ) as executor :
504+ executor .add_all (instances )
505+ executor .commit ()
506+ if refresh :
507+ [executor .refresh (instance ) for instance in instances ]
519508
520509 def run_sync (
521510 self ,
@@ -531,11 +520,35 @@ def run_sync(
531520 if executor is None or not isinstance (executor , (Session , Connection )):
532521 need_close = True
533522 executor = self .session_maker () if is_session else self .engine .connect ()
534- result = fn (executor , * args , ** kwargs )
535- if on_close_pre :
536- result = on_close_pre (result )
537- if commit :
538- executor .commit ()
539- if need_close :
540- executor .close ()
541- return result
523+ with ExecutorContextManager (executor , need_close = need_close ) as executor :
524+ result = fn (executor , * args , ** kwargs )
525+ if on_close_pre :
526+ result = on_close_pre (result )
527+ if commit :
528+ executor .commit ()
529+ return result
530+
531+ class ExecutorContextManager :
532+ """Actuator context manager, optionally closing the executor"""
533+
534+ def __init__ (
535+ self ,
536+ executor : Union [Session , Connection , AsyncSession , AsyncConnection ],
537+ need_close : bool = True
538+ ):
539+ self .executor = executor
540+ self .need_close = need_close
541+
542+ def __enter__ (self ):
543+ return self .executor
544+
545+ def __exit__ (self , exc_type , exc_val , exc_tb ):
546+ if self .need_close :
547+ self .executor .close ()
548+
549+ async def __aenter__ (self ):
550+ return self .executor
551+
552+ async def __aexit__ (self , exc_type , exc_val , exc_tb ):
553+ if self .need_close :
554+ await self .executor .close ()
0 commit comments