Skip to content

Commit b69461e

Browse files
committed
feat: Add the run_sync method on_close_pre hook
1 parent 99b9c1e commit b69461e

2 files changed

Lines changed: 41 additions & 16 deletions

File tree

sqlalchemy_database/_abc_async_database.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ from typing_extensions import ParamSpec, Concatenate
88

99
_P = ParamSpec("_P")
1010
_T = TypeVar("_T")
11+
_R = TypeVar("_R")
1112

1213
_ExecuteParams = Union[Mapping[Any, Any], Sequence[Mapping[Any, Any]]]
1314
_ExecuteOptions = Mapping[Any, Any]
@@ -19,6 +20,7 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
1920
self,
2021
statement: Executable,
2122
params: Optional[_ExecuteParams] = None,
23+
*,
2224
execution_options: Optional[_ExecuteOptions] = None,
2325
bind_arguments: Optional[Mapping[str, Any]] = None,
2426
commit: bool = False,
@@ -30,6 +32,7 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
3032
self,
3133
statement: Executable,
3234
params: Optional[_ExecuteParams] = None,
35+
*,
3336
execution_options: Optional[_ExecuteOptions] = None,
3437
bind_arguments: Optional[Mapping[str, Any]] = None,
3538
**kw: Any,
@@ -39,6 +42,7 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
3942
self,
4043
statement: Executable,
4144
params: Optional[_ExecuteParams] = None,
45+
*,
4246
execution_options: Optional[_ExecuteOptions] = None,
4347
**kw: Any,
4448
) -> List[Any]: ...
@@ -47,6 +51,7 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
4751
self,
4852
entity: Type[_T],
4953
ident: Any,
54+
*,
5055
options: Optional[Sequence[Any]] = None,
5156
populate_existing: bool = False,
5257
with_for_update: Optional[Any] = None,
@@ -60,6 +65,8 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
6065
self,
6166
fn: Callable[[Concatenate[Union[Session, Connection], _P]], _T],
6267
*args: _P.args,
68+
commit: bool = True,
69+
on_close_pre: Callable[[_T], _R] = None,
6370
is_session: bool = False,
6471
**kwargs: _P.kwargs
65-
) -> _T: ...
72+
) -> Union[_T,_R]: ...

sqlalchemy_database/database.py

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99

1010
from sqlalchemy_database._abc_async_database import AbcAsyncDatabase
1111

12-
_T = TypeVar("_T")
1312
_P = ParamSpec("_P")
13+
_T = TypeVar("_T")
14+
_R = TypeVar("_R")
15+
1416
_ExecuteParams = Union[Mapping[Any, Any], Sequence[Mapping[Any, Any]]]
1517
_ExecuteOptions = Mapping[Any, Any]
1618

@@ -38,6 +40,7 @@ async def execute(
3840
self,
3941
statement: Executable,
4042
params: Optional[_ExecuteParams] = None,
43+
*,
4144
execution_options: Optional[_ExecuteOptions] = None,
4245
bind_arguments: Optional[Mapping[str, Any]] = None,
4346
commit: bool = False,
@@ -56,6 +59,7 @@ async def scalar(
5659
self,
5760
statement: Executable,
5861
params: Optional[_ExecuteParams] = None,
62+
*,
5963
execution_options: Optional[_ExecuteOptions] = None,
6064
bind_arguments: Optional[Mapping[str, Any]] = None,
6165
**kw: Any,
@@ -73,6 +77,7 @@ async def scalars_all(
7377
self,
7478
statement: Executable,
7579
params: Optional[_ExecuteParams] = None,
80+
*,
7681
execution_options: Optional[_ExecuteOptions] = None,
7782
**kw: Any,
7883
) -> List[Any]:
@@ -89,6 +94,7 @@ async def get(
8994
self,
9095
entity: Type[_T],
9196
ident: Any,
97+
*,
9298
options: Optional[Sequence[Any]] = None,
9399
populate_existing: bool = False,
94100
with_for_update: Optional[Any] = None,
@@ -114,15 +120,19 @@ async def run_sync(
114120
self,
115121
fn: Callable[[Concatenate[Union[Session, Connection], _P]], _T],
116122
*args: _P.args,
123+
commit: bool = True,
124+
on_close_pre: Callable[[_T], _R] = None,
117125
is_session: bool = False,
118126
**kwargs: _P.kwargs
119-
) -> _T:
120-
if is_session:
121-
async with self.session_maker() as session:
122-
async with session.begin():
123-
return await session.run_sync(fn, *args, **kwargs)
124-
async with self.engine.begin() as conn:
125-
return await conn.run_sync(fn, *args, **kwargs)
127+
) -> Union[_T,_R]:
128+
maker = self.session_maker if is_session else self.engine.connect
129+
async with maker() as conn:
130+
result = await conn.run_sync(fn, *args, **kwargs)
131+
if commit:
132+
await conn.commit()
133+
if on_close_pre:
134+
result = on_close_pre(result)
135+
return result
126136

127137

128138
class Database(AbcAsyncDatabase):
@@ -147,6 +157,7 @@ def execute(
147157
self,
148158
statement: Executable,
149159
params: Optional[_ExecuteParams] = None,
160+
*,
150161
execution_options: Optional[_ExecuteOptions] = None,
151162
bind_arguments: Optional[Mapping[str, Any]] = None,
152163
commit: bool = False,
@@ -165,6 +176,7 @@ def scalar(
165176
self,
166177
statement: Executable,
167178
params: Optional[_ExecuteParams] = None,
179+
*,
168180
execution_options: Optional[_ExecuteOptions] = None,
169181
bind_arguments: Optional[Mapping[str, Any]] = None,
170182
**kw: Any,
@@ -182,6 +194,7 @@ def scalars_all(
182194
self,
183195
statement: Executable,
184196
params: Optional[_ExecuteParams] = None,
197+
*,
185198
execution_options: Optional[_ExecuteOptions] = None,
186199
bind_arguments: Optional[Mapping[str, Any]] = None,
187200
**kw: Any,
@@ -199,6 +212,7 @@ def get(
199212
self,
200213
entity: Type[_T],
201214
ident: Any,
215+
*,
202216
options: Optional[Sequence[Any]] = None,
203217
populate_existing: bool = False,
204218
with_for_update: Optional[Any] = None,
@@ -224,12 +238,16 @@ def run_sync(
224238
self,
225239
fn: Callable[[Concatenate[Union[Session, Connection], _P]], _T],
226240
*args: _P.args,
241+
commit: bool = True,
242+
on_close_pre: Callable[[_T], _R] = None,
227243
is_session: bool = False,
228244
**kwargs: _P.kwargs
229-
) -> _T:
230-
if is_session:
231-
with self.session_maker() as session:
232-
with session.begin():
233-
return fn(session, *args, **kwargs)
234-
with self.engine.begin() as conn:
235-
return fn(conn, *args, **kwargs)
245+
) -> Union[_T, _R]:
246+
maker = self.session_maker if is_session else self.engine.connect
247+
with maker() as conn:
248+
result = fn(conn, *args, **kwargs)
249+
if commit:
250+
conn.commit()
251+
if on_close_pre:
252+
result = on_close_pre(result)
253+
return result

0 commit comments

Comments
 (0)