Skip to content

Commit c730478

Browse files
committed
feat: support custom executor.
1 parent 7aacd13 commit c730478

7 files changed

Lines changed: 425 additions & 326 deletions

File tree

sqlalchemy_database/_abc_async_database.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,13 @@
1212
_P = ParamSpec("_P")
1313
_R = TypeVar("_R")
1414

15-
1615
async def to_thread(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: # noqa: E303
1716
loop = asyncio.get_running_loop()
1817
ctx = contextvars.copy_context()
1918
func_call = functools.partial(ctx.run, func, *args, **kwargs)
2019
return await loop.run_in_executor(None, func_call)
2120

22-
23-
class AbcAsyncDatabase(metaclass=abc.ABCMeta):
21+
class AbcAsyncDatabase(metaclass = abc.ABCMeta):
2422

2523
def __init__(self) -> None:
2624
for func_name in ['execute', 'scalar', 'scalars_all', 'get', 'delete', 'save', 'run_sync']:

sqlalchemy_database/_abc_async_database.pyi

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import abc
22
from typing import Any, Optional, Mapping, Union, Sequence, Type, List, TypeVar, Callable
33

44
from sqlalchemy.engine import Result, Connection
5+
from sqlalchemy.ext.asyncio import AsyncSession, AsyncConnection
56
from sqlalchemy.orm import Session
67
from sqlalchemy.sql import Executable
78
from typing_extensions import ParamSpec, Concatenate
@@ -13,66 +14,75 @@ _R = TypeVar("_R")
1314
_ExecuteParams = Union[Mapping[Any, Any], Sequence[Mapping[Any, Any]]]
1415
_ExecuteOptions = Mapping[Any, Any]
1516

16-
17-
class AbcAsyncDatabase(metaclass=abc.ABCMeta):
17+
class AbcAsyncDatabase(metaclass = abc.ABCMeta):
1818
"""`sqlalchemy` asynchronous database abstract base class, not directly instantiated
1919
2020
"""
2121

2222
async def async_execute(
23-
self,
24-
statement: Executable,
25-
params: Optional[_ExecuteParams] = None,
26-
*,
27-
execution_options: Optional[_ExecuteOptions] = None,
28-
bind_arguments: Optional[Mapping[str, Any]] = None,
29-
commit: bool = True,
30-
on_close_pre: Callable[[Result], _T] = None,
31-
is_session: bool = True,
32-
**kw: Any,
23+
self,
24+
statement: Executable,
25+
params: Optional[_ExecuteParams] = None,
26+
*,
27+
execution_options: Optional[_ExecuteOptions] = None,
28+
bind_arguments: Optional[Mapping[str, Any]] = None,
29+
commit: bool = True,
30+
on_close_pre: Callable[[Result], _T] = None,
31+
is_session: bool = True,
32+
executor: Union[Session, Connection, AsyncSession, AsyncConnection, None] = None,
33+
**kw: Any,
3334
) -> Union[Result, _T]: ...
3435

3536
async def async_scalar(
36-
self,
37-
statement: Executable,
38-
params: Optional[_ExecuteParams] = None,
39-
*,
40-
execution_options: Optional[_ExecuteOptions] = None,
41-
bind_arguments: Optional[Mapping[str, Any]] = None,
42-
**kw: Any,
37+
self,
38+
statement: Executable,
39+
params: Optional[_ExecuteParams] = None,
40+
*,
41+
execution_options: Optional[_ExecuteOptions] = None,
42+
bind_arguments: Optional[Mapping[str, Any]] = None,
43+
session: Union[Session, AsyncSession, None] = None,
44+
**kw: Any,
4345
) -> Any: ...
4446

4547
async def async_scalars_all(
46-
self,
47-
statement: Executable,
48-
params: Optional[_ExecuteParams] = None,
49-
*,
50-
execution_options: Optional[_ExecuteOptions] = None,
51-
**kw: Any,
48+
self,
49+
statement: Executable,
50+
params: Optional[_ExecuteParams] = None,
51+
*,
52+
execution_options: Optional[_ExecuteOptions] = None,
53+
session: Union[Session, AsyncSession, None] = None,
54+
**kw: Any,
5255
) -> List[Any]: ...
5356

5457
async def async_get(
55-
self,
56-
entity: Type[_T],
57-
ident: Any,
58-
*,
59-
options: Optional[Sequence[Any]] = None,
60-
populate_existing: bool = False,
61-
with_for_update: Optional[Any] = None,
62-
identity_token: Optional[Any] = None,
63-
execution_options: Optional[_ExecuteOptions] = None,
58+
self,
59+
entity: Type[_T],
60+
ident: Any,
61+
*,
62+
options: Optional[Sequence[Any]] = None,
63+
populate_existing: bool = False,
64+
with_for_update: Optional[Any] = None,
65+
identity_token: Optional[Any] = None,
66+
execution_options: Optional[_ExecuteOptions] = None,
67+
session: Union[Session, AsyncSession, None] = None
6468
) -> Optional[_T]: ...
6569

6670
async def async_delete(self, instance: Any) -> None: ...
6771

68-
async def async_save(self, *instances: Any, refresh: bool = False) -> None: ...
72+
async def async_save(
73+
self,
74+
*instances: Any,
75+
refresh: bool = False,
76+
session: Union[Session, AsyncSession, None] = None
77+
) -> None: ...
6978

7079
async def async_run_sync(
71-
self,
72-
fn: Callable[[Concatenate[Union[Session, Connection], _P]], _T],
73-
*args: _P.args,
74-
commit: bool = True,
75-
on_close_pre: Callable[[_T], _R] = None,
76-
is_session: bool = True,
77-
**kwargs: _P.kwargs
80+
self,
81+
fn: Callable[[Concatenate[Union[Session, Connection], _P]], _T],
82+
*args: _P.args,
83+
commit: bool = True,
84+
on_close_pre: Callable[[_T], _R] = None,
85+
is_session: bool = True,
86+
executor: Union[Session, Connection, AsyncSession, AsyncConnection, None] = None,
87+
**kwargs: _P.kwargs
7888
) -> Union[_T, _R]: ...

0 commit comments

Comments
 (0)