Skip to content

Commit 7afc886

Browse files
committed
perf: Use the context manager to ensure that the executor shuts down properly
1 parent b081bf3 commit 7afc886

1 file changed

Lines changed: 118 additions & 105 deletions

File tree

sqlalchemy_database/database.py

Lines changed: 118 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,13 @@ async def execute(
143143
kw['bind_arguments'] = bind_arguments
144144
else:
145145
executor = await self.engine.connect()
146-
result = await executor.execute(statement, params, execution_options, **kw) # type:ignore
147-
if on_close_pre:
148-
result = on_close_pre(result)
149-
if commit and not isinstance(statement, Select):
150-
await executor.commit()
151-
if need_close:
152-
await executor.close()
153-
return result
146+
async with ExecutorContextManager(executor, need_close = need_close) as executor:
147+
result = await executor.execute(statement, params, execution_options, **kw) # type:ignore
148+
if on_close_pre:
149+
result = on_close_pre(result)
150+
if commit and not isinstance(statement, Select):
151+
await executor.commit()
152+
return result
154153

155154
async def scalar(
156155
self,
@@ -172,16 +171,15 @@ async def scalar(
172171
executor = self.session_maker()
173172
else:
174173
executor = session
175-
result = await executor.scalar(
176-
statement,
177-
params,
178-
execution_options = execution_options,
179-
bind_arguments = bind_arguments,
180-
**kw,
181-
)
182-
if session is None:
183-
await executor.close()
184-
return result
174+
async with ExecutorContextManager(executor, need_close = session is None) as executor:
175+
result = await executor.scalar(
176+
statement,
177+
params,
178+
execution_options = execution_options,
179+
bind_arguments = bind_arguments,
180+
**kw,
181+
)
182+
return result
185183

186184
async def scalars_all(
187185
self,
@@ -202,15 +200,14 @@ async def scalars_all(
202200
executor = self.session_maker()
203201
else:
204202
executor = session
205-
result = (await executor.scalars(
206-
statement,
207-
params,
208-
execution_options = execution_options,
209-
**kw,
210-
)).all()
211-
if session is None:
212-
await executor.close()
213-
return result
203+
async with ExecutorContextManager(executor, need_close = session is None) as executor:
204+
result = (await executor.scalars(
205+
statement,
206+
params,
207+
execution_options = execution_options,
208+
**kw,
209+
)).all()
210+
return result
214211

215212
async def get(
216213
self,
@@ -268,17 +265,16 @@ async def get(
268265
executor = self.session_maker()
269266
else:
270267
executor = session
271-
result = await executor.get(
272-
entity,
273-
ident,
274-
options = options,
275-
populate_existing = populate_existing,
276-
with_for_update = with_for_update,
277-
identity_token = identity_token,
278-
)
279-
if session is None:
280-
await executor.close()
281-
return result
268+
async with ExecutorContextManager(executor, need_close = session is None) as executor:
269+
result = await executor.get(
270+
entity,
271+
ident,
272+
options = options,
273+
populate_existing = populate_existing,
274+
with_for_update = with_for_update,
275+
identity_token = identity_token,
276+
)
277+
return result
282278

283279
async def delete(self, instance: Any) -> None:
284280
"""Deletes an instance object."""
@@ -303,12 +299,11 @@ async def save(
303299
executor = self.session_maker()
304300
else:
305301
executor = session
306-
executor.add_all(instances)
307-
await executor.commit()
308-
if refresh:
309-
[await executor.refresh(instance) for instance in instances]
310-
if session is None:
311-
await executor.close()
302+
async with ExecutorContextManager(executor, need_close = session is None) as executor:
303+
executor.add_all(instances)
304+
await executor.commit()
305+
if refresh:
306+
[await executor.refresh(instance) for instance in instances]
312307

313308
async def run_sync(
314309
self,
@@ -357,14 +352,13 @@ def get_user(session:Session,id:int):
357352
if executor is None or not isinstance(executor, (AsyncSession, AsyncConnection)):
358353
need_close = True
359354
executor = self.session_maker() if is_session else await self.engine.connect()
360-
result = await executor.run_sync(fn, *args, **kwargs)
361-
if on_close_pre:
362-
result = on_close_pre(result)
363-
if commit:
364-
await executor.commit()
365-
if need_close:
366-
await executor.close()
367-
return result
355+
async with ExecutorContextManager(executor, need_close = need_close) as executor:
356+
result = await executor.run_sync(fn, *args, **kwargs)
357+
if on_close_pre:
358+
result = on_close_pre(result)
359+
if commit:
360+
await executor.commit()
361+
return result
368362

369363
class Database(AbcAsyncDatabase):
370364
"""`sqlalchemy` synchronous database client
@@ -408,14 +402,13 @@ def execute(
408402
kw['bind_arguments'] = bind_arguments
409403
else:
410404
executor = self.engine.connect()
411-
result = executor.execute(statement, params, execution_options, **kw)
412-
if on_close_pre:
413-
result = on_close_pre(result)
414-
if commit and not isinstance(statement, Select):
415-
executor.commit()
416-
if need_close:
417-
executor.close()
418-
return result
405+
with ExecutorContextManager(executor, need_close = need_close) as executor:
406+
result = executor.execute(statement, params, execution_options, **kw)
407+
if on_close_pre:
408+
result = on_close_pre(result)
409+
if commit and not isinstance(statement, Select):
410+
executor.commit()
411+
return result
419412

420413
def scalar(
421414
self,
@@ -431,16 +424,15 @@ def scalar(
431424
executor = self.session_maker()
432425
else:
433426
executor = session
434-
result = executor.scalar(
435-
statement,
436-
params,
437-
execution_options = execution_options,
438-
bind_arguments = bind_arguments,
439-
**kw,
440-
)
441-
if session is None:
442-
executor.close()
443-
return result
427+
with ExecutorContextManager(executor, need_close = session is None) as executor:
428+
result = executor.scalar(
429+
statement,
430+
params,
431+
execution_options = execution_options,
432+
bind_arguments = bind_arguments,
433+
**kw,
434+
)
435+
return result
444436

445437
def scalars_all(
446438
self,
@@ -456,16 +448,15 @@ def scalars_all(
456448
executor = self.session_maker()
457449
else:
458450
executor = session
459-
result = executor.scalars(
460-
statement,
461-
params,
462-
execution_options = execution_options,
463-
bind_arguments = bind_arguments,
464-
**kw,
465-
).all()
466-
if session is None:
467-
executor.close()
468-
return result
451+
with ExecutorContextManager(executor, need_close = session is None) as executor:
452+
result = executor.scalars(
453+
statement,
454+
params,
455+
execution_options = execution_options,
456+
bind_arguments = bind_arguments,
457+
**kw,
458+
).all()
459+
return result
469460

470461
def get(
471462
self,
@@ -483,17 +474,16 @@ def get(
483474
executor = self.session_maker()
484475
else:
485476
executor = session
486-
result = executor.get(
487-
entity,
488-
ident,
489-
options = options,
490-
populate_existing = populate_existing,
491-
with_for_update = with_for_update,
492-
identity_token = identity_token,
493-
)
494-
if session is None:
495-
executor.close()
496-
return result
477+
with ExecutorContextManager(executor, need_close = session is None) as executor:
478+
result = executor.get(
479+
entity,
480+
ident,
481+
options = options,
482+
populate_existing = populate_existing,
483+
with_for_update = with_for_update,
484+
identity_token = identity_token,
485+
)
486+
return result
497487

498488
def delete(self, instance: Any) -> None:
499489
with self.session_maker() as session:
@@ -510,12 +500,11 @@ def save(
510500
executor = self.session_maker()
511501
else:
512502
executor = session
513-
executor.add_all(instances)
514-
executor.commit()
515-
if refresh:
516-
[executor.refresh(instance) for instance in instances]
517-
if session is None:
518-
executor.close()
503+
with ExecutorContextManager(executor, need_close = session is None) as executor:
504+
executor.add_all(instances)
505+
executor.commit()
506+
if refresh:
507+
[executor.refresh(instance) for instance in instances]
519508

520509
def run_sync(
521510
self,
@@ -531,11 +520,35 @@ def run_sync(
531520
if executor is None or not isinstance(executor, (Session, Connection)):
532521
need_close = True
533522
executor = self.session_maker() if is_session else self.engine.connect()
534-
result = fn(executor, *args, **kwargs)
535-
if on_close_pre:
536-
result = on_close_pre(result)
537-
if commit:
538-
executor.commit()
539-
if need_close:
540-
executor.close()
541-
return result
523+
with ExecutorContextManager(executor, need_close = need_close) as executor:
524+
result = fn(executor, *args, **kwargs)
525+
if on_close_pre:
526+
result = on_close_pre(result)
527+
if commit:
528+
executor.commit()
529+
return result
530+
531+
class ExecutorContextManager:
532+
"""Actuator context manager, optionally closing the executor"""
533+
534+
def __init__(
535+
self,
536+
executor: Union[Session, Connection, AsyncSession, AsyncConnection],
537+
need_close: bool = True
538+
):
539+
self.executor = executor
540+
self.need_close = need_close
541+
542+
def __enter__(self):
543+
return self.executor
544+
545+
def __exit__(self, exc_type, exc_val, exc_tb):
546+
if self.need_close:
547+
self.executor.close()
548+
549+
async def __aenter__(self):
550+
return self.executor
551+
552+
async def __aexit__(self, exc_type, exc_val, exc_tb):
553+
if self.need_close:
554+
await self.executor.close()

0 commit comments

Comments
 (0)