Skip to content

Commit 03cf5db

Browse files
committed
style: Extract the executor maker method
1 parent ed062d4 commit 03cf5db

1 file changed

Lines changed: 66 additions & 133 deletions

File tree

sqlalchemy_database/database.py

Lines changed: 66 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,27 @@ async def session_generator(self) -> AsyncGenerator[AsyncSession, Any]:
126126
async with self.session_maker() as session:
127127
yield session
128128

129+
async def _executor_maker(
130+
self, executor: Union[AsyncSession, AsyncConnection, None] = None, is_session: bool = True
131+
) -> "ExecutorContextManager":
132+
"""Create an `AsyncSession` or `AsyncConnection` context manager.
133+
If the executor is specified, it will be returned directly.
134+
If not,but the `self.session` is not None, it will be returned directly.
135+
Otherwise, it will be created, and closed after the context manager exits.
136+
"""
137+
need_close = False
138+
if executor is None or not isinstance(executor, (AsyncSession, AsyncConnection)):
139+
need_close = True
140+
if is_session:
141+
executor = self.session
142+
if executor is None:
143+
executor = self.session_maker()
144+
else:
145+
need_close = False
146+
else:
147+
executor = await self.engine.connect()
148+
return ExecutorContextManager(executor, need_close=need_close)
149+
129150
async def execute(
130151
self,
131152
statement: Executable,
@@ -183,19 +204,9 @@ async def execute(
183204
cannot be called again after the connection is closed.
184205
185206
"""
186-
need_close = False
187-
if executor is None or not isinstance(executor, (AsyncSession, AsyncConnection)):
188-
need_close = True
189-
if is_session:
190-
executor = self.session
191-
if executor is None:
192-
executor = self.session_maker()
193-
else:
194-
need_close = False
195-
kw["bind_arguments"] = bind_arguments
196-
else:
197-
executor = await self.engine.connect()
198-
async with ExecutorContextManager(executor, need_close=need_close) as executor:
207+
if is_session:
208+
kw["bind_arguments"] = bind_arguments
209+
async with await self._executor_maker(executor, is_session) as executor:
199210
result = await executor.execute(statement, params, execution_options, **kw) # type:ignore
200211
if on_close_pre:
201212
result = on_close_pre(result)
@@ -219,13 +230,7 @@ async def scalar(
219230
Usage and parameters are the same as that of :meth:`_orm.Session.execute`;
220231
the return result is a scalar Python value.
221232
"""
222-
need_close = False
223-
if session is None or not isinstance(session, AsyncSession):
224-
session = self.session
225-
if session is None:
226-
need_close = True
227-
session = self.session_maker()
228-
async with ExecutorContextManager(session, need_close=need_close) as session:
233+
async with await self._executor_maker(session) as session:
229234
result = await session.scalar(
230235
statement,
231236
params,
@@ -250,13 +255,7 @@ async def scalars_all(
250255
Usage and parameters are the same as that of :meth:`_orm.Session.execute`;
251256
the return result is a list of scalar Python value.
252257
"""
253-
need_close = False
254-
if session is None or not isinstance(session, AsyncSession):
255-
session = self.session
256-
if session is None:
257-
need_close = True
258-
session = self.session_maker()
259-
async with ExecutorContextManager(session, need_close=need_close) as session:
258+
async with await self._executor_maker(session) as session:
260259
result = (
261260
await session.scalars(
262261
statement,
@@ -319,13 +318,7 @@ async def get(
319318
)
320319
```
321320
"""
322-
need_close = False
323-
if session is None or not isinstance(session, AsyncSession):
324-
session = self.session
325-
if session is None:
326-
need_close = True
327-
session = self.session_maker()
328-
async with ExecutorContextManager(session, need_close=need_close) as session:
321+
async with await self._executor_maker(session) as session:
329322
result = await session.get(
330323
entity,
331324
ident,
@@ -338,13 +331,9 @@ async def get(
338331

339332
async def delete(self, instance: Any) -> None:
340333
"""Deletes an instance object."""
341-
if self.session is not None:
342-
await self.session.delete(instance)
343-
await self.session.commit()
344-
else:
345-
async with self.session_maker() as session:
346-
async with session.begin():
347-
await session.delete(instance)
334+
async with await self._executor_maker() as session:
335+
await session.delete(instance)
336+
await session.commit()
348337

349338
async def save(self, *instances: Any, refresh: bool = False, session: Optional[AsyncSession] = None) -> None:
350339
"""
@@ -354,13 +343,7 @@ async def save(self, *instances: Any, refresh: bool = False, session: Optional[A
354343
Args:
355344
session: If not specified, an `AsyncSession` is created.
356345
"""
357-
need_close = False
358-
if session is None or not isinstance(session, AsyncSession):
359-
session = self.session
360-
if session is None:
361-
need_close = True
362-
session = self.session_maker()
363-
async with ExecutorContextManager(session, need_close=need_close) as session:
346+
async with await self._executor_maker(session) as session:
364347
session.add_all(instances)
365348
await session.commit()
366349
if refresh:
@@ -409,17 +392,7 @@ def get_user(session:Session,id:int):
409392
callable should only call into SQLAlchemy's asyncio database
410393
APIs which will be properly adapted to the greenlet context.
411394
"""
412-
need_close = False
413-
if executor is None or not isinstance(executor, (AsyncSession, AsyncConnection)):
414-
if is_session:
415-
executor = self.session
416-
if executor is None:
417-
executor = self.session_maker()
418-
need_close = True
419-
else:
420-
executor = await self.engine.connect()
421-
need_close = True
422-
async with ExecutorContextManager(executor, need_close=need_close) as executor:
395+
async with await self._executor_maker(executor, is_session) as executor:
423396
result = await executor.run_sync(fn, *args, **kwargs)
424397
if on_close_pre:
425398
result = on_close_pre(result)
@@ -440,13 +413,7 @@ async def refresh(self, instance, attribute_names=None, with_for_update=None, se
440413
Supersedes the :paramref:`.Session.refresh.lockmode` parameter.
441414
session: If not specified, an `AsyncSession` is created.
442415
"""
443-
need_close = False
444-
if session is None or not isinstance(session, AsyncSession):
445-
session = self.session
446-
if session is None:
447-
need_close = True
448-
session = self.session_maker()
449-
async with ExecutorContextManager(session, need_close=need_close) as session:
416+
async with await self._executor_maker(session) as session:
450417
await session.refresh(instance, attribute_names, with_for_update)
451418

452419

@@ -479,6 +446,27 @@ def session_generator(self) -> Generator[Session, Any, None]:
479446
with self.session_maker() as session:
480447
yield session
481448

449+
def _executor_maker(
450+
self, executor: Union[Session, Connection, None] = None, is_session: bool = True
451+
) -> "ExecutorContextManager":
452+
"""Create an `Session` or `Connection` context manager.
453+
If the executor is specified, it will be returned directly.
454+
If not,but the `self.session` is not None, it will be returned directly.
455+
Otherwise, it will be created, and closed after the context manager exits.
456+
"""
457+
need_close = False
458+
if executor is None or not isinstance(executor, (Session, Connection)):
459+
need_close = True
460+
if is_session:
461+
executor = self.session
462+
if executor is None:
463+
executor = self.session_maker()
464+
else:
465+
need_close = False
466+
else:
467+
executor = self.engine.connect()
468+
return ExecutorContextManager(executor, need_close=need_close)
469+
482470
def execute(
483471
self,
484472
statement: Executable,
@@ -492,19 +480,9 @@ def execute(
492480
executor: Union[Session, Connection, None] = None,
493481
**kw: Any,
494482
) -> Union[Result, _T]:
495-
need_close = False
496-
if executor is None or not isinstance(executor, (Session, Connection)):
497-
need_close = True
498-
if is_session:
499-
executor = self.session
500-
if executor is None:
501-
executor = self.session_maker()
502-
else:
503-
need_close = False
504-
kw["bind_arguments"] = bind_arguments
505-
else:
506-
executor = self.engine.connect()
507-
with ExecutorContextManager(executor, need_close=need_close) as executor:
483+
if is_session:
484+
kw["bind_arguments"] = bind_arguments
485+
with self._executor_maker(executor, is_session) as executor:
508486
result = executor.execute(statement, params, execution_options, **kw)
509487
if on_close_pre:
510488
result = on_close_pre(result)
@@ -522,13 +500,7 @@ def scalar(
522500
session: Optional[Session] = None,
523501
**kw: Any,
524502
) -> Any:
525-
need_close = False
526-
if session is None or not isinstance(session, Session):
527-
session = self.session
528-
if session is None:
529-
need_close = True
530-
session = self.session_maker()
531-
with ExecutorContextManager(session, need_close=need_close) as session:
503+
with self._executor_maker(session) as session:
532504
result = session.scalar(
533505
statement,
534506
params,
@@ -548,13 +520,7 @@ def scalars_all(
548520
session: Optional[Session] = None,
549521
**kw: Any,
550522
) -> List[Any]:
551-
need_close = False
552-
if session is None or not isinstance(session, Session):
553-
session = self.session
554-
if session is None:
555-
need_close = True
556-
session = self.session_maker()
557-
with ExecutorContextManager(session, need_close=need_close) as session:
523+
with self._executor_maker(session) as session:
558524
result = session.scalars(
559525
statement,
560526
params,
@@ -576,13 +542,7 @@ def get(
576542
execution_options: Optional[_ExecuteOptions] = None,
577543
session: Optional[Session] = None,
578544
) -> Optional[_T]:
579-
need_close = False
580-
if session is None or not isinstance(session, Session):
581-
session = self.session
582-
if session is None:
583-
need_close = True
584-
session = self.session_maker()
585-
with ExecutorContextManager(session, need_close=need_close) as session:
545+
with self._executor_maker(session) as session:
586546
result = session.get(
587547
entity,
588548
ident,
@@ -594,22 +554,12 @@ def get(
594554
return result
595555

596556
def delete(self, instance: Any) -> None:
597-
if self.session is not None:
598-
self.session.delete(instance)
599-
self.session.commit()
600-
else:
601-
with self.session_maker() as session:
602-
with session.begin():
603-
session.delete(instance)
557+
with self._executor_maker() as session:
558+
session.delete(instance)
559+
session.commit()
604560

605561
def save(self, *instances: Any, refresh: bool = False, session: Optional[Session] = None) -> None:
606-
need_close = False
607-
if session is None or not isinstance(session, Session):
608-
session = self.session
609-
if session is None:
610-
need_close = True
611-
session = self.session_maker()
612-
with ExecutorContextManager(session, need_close=need_close) as session:
562+
with self._executor_maker(session) as session:
613563
session.add_all(instances)
614564
session.commit()
615565
if refresh:
@@ -625,18 +575,7 @@ def run_sync(
625575
executor: Union[Session, Connection, None] = None,
626576
**kwargs: _P.kwargs,
627577
) -> Union[_T, _R]:
628-
need_close = False
629-
if executor is None or not isinstance(executor, (Session, Connection)):
630-
need_close = True
631-
if is_session:
632-
executor = self.session
633-
if executor is None:
634-
executor = self.session_maker()
635-
else:
636-
need_close = False
637-
else:
638-
executor = self.engine.connect()
639-
with ExecutorContextManager(executor, need_close=need_close) as executor:
578+
with self._executor_maker(executor, is_session) as executor:
640579
result = fn(executor, *args, **kwargs)
641580
if on_close_pre:
642581
result = on_close_pre(result)
@@ -645,13 +584,7 @@ def run_sync(
645584
return result
646585

647586
def refresh(self, instance, attribute_names=None, with_for_update=None, session: Optional[Session] = None):
648-
need_close = False
649-
if session is None or not isinstance(session, Session):
650-
session = self.session
651-
if session is None:
652-
need_close = True
653-
session = self.session_maker()
654-
with ExecutorContextManager(session, need_close=need_close) as session:
587+
with self._executor_maker(session) as session:
655588
session.refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update)
656589

657590

0 commit comments

Comments
 (0)