Skip to content

Commit d5437ce

Browse files
committed
perf: update the run_sync and execute default parameters
1 parent 4cec0a0 commit d5437ce

7 files changed

Lines changed: 90 additions & 60 deletions

File tree

README.md

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ async def fast_execute():
102102

103103
# insert
104104
stmt = insert(User).values({'username': 'User-6', 'password': 'password-6'})
105-
result = await db.execute(stmt, commit=True)
105+
result = await db.execute(stmt)
106106

107107
# delete
108108
stmt = delete(User).where(User.id == 6)
109-
result = await db.execute(stmt, commit=True)
109+
result = await db.execute(stmt)
110110

111111
# scalar
112112
user = await db.scalar(select(User).where(User.id == 1))
@@ -123,7 +123,7 @@ async def fast_execute():
123123
await db.delete(user)
124124

125125
# run_sync
126-
await db.run_sync(Base.metadata.create_all)
126+
await db.run_sync(Base.metadata.create_all, is_session=False)
127127

128128
# session_maker
129129
async with db.session_maker() as session:
@@ -162,11 +162,11 @@ def fast_execute():
162162

163163
# insert
164164
stmt = insert(User).values({'username': 'User-6', 'password': 'password-6'})
165-
result = db.execute(stmt, commit=True)
165+
result = db.execute(stmt)
166166

167167
# delete
168168
stmt = delete(User).where(User.id == 6)
169-
result = db.execute(stmt, commit=True)
169+
result = db.execute(stmt)
170170

171171
# scalar
172172
user = db.scalar(select(User).where(User.id == 1))
@@ -183,7 +183,7 @@ def fast_execute():
183183
db.delete(user)
184184

185185
# run_sync
186-
db.run_sync(Base.metadata.create_all)
186+
db.run_sync(Base.metadata.create_all, is_session=False)
187187

188188
# session_maker
189189
with db.session_maker() as session:
@@ -215,19 +215,19 @@ from sqlalchemy_database import AsyncDatabase, Database
215215
async def fast_execute(db: Union[AsyncDatabase, Database]):
216216
# update
217217
stmt = update(User).where(User.id == 1).values({'username': 'new_user'})
218-
result = await db.async_execute(stmt, commit=True)
218+
result = await db.async_execute(stmt)
219219

220220
# select
221221
stmt = select(User).where(User.id == 1)
222222
user = await db.async_execute(stmt, on_close_pre=lambda r: r.scalar())
223223

224224
# insert
225225
stmt = insert(User).values({'username': 'User-6', 'password': 'password-6'})
226-
result = await db.async_execute(stmt, commit=True)
226+
result = await db.async_execute(stmt)
227227

228228
# delete
229229
stmt = delete(User).where(User.id == 6)
230-
result = await db.async_execute(stmt, commit=True)
230+
result = await db.async_execute(stmt)
231231

232232
# scalar
233233
user = await db.async_scalar(select(User).where(User.id == 1))
@@ -244,7 +244,7 @@ async def fast_execute(db: Union[AsyncDatabase, Database]):
244244
await db.async_delete(user)
245245

246246
# run_sync
247-
await db.async_run_sync(Base.metadata.create_all)
247+
await db.async_run_sync(Base.metadata.create_all, is_session=False)
248248

249249
```
250250

README.zh.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,19 +94,19 @@ from sqlalchemy import insert, select, update, delete
9494
async def fast_execute():
9595
# update
9696
stmt = update(User).where(User.id == 1).values({'username': 'new_user'})
97-
result = await db.execute(stmt, commit=True)
97+
result = await db.execute(stmt)
9898

9999
# select
100100
stmt = select(User).where(User.id == 1)
101101
user = await db.execute(stmt, on_close_pre=lambda r: r.scalar())
102102

103103
# insert
104104
stmt = insert(User).values({'username': 'User-6', 'password': 'password-6'})
105-
result = await db.execute(stmt, commit=True)
105+
result = await db.execute(stmt)
106106

107107
# delete
108108
stmt = delete(User).where(User.id == 6)
109-
result = await db.execute(stmt, commit=True)
109+
result = await db.execute(stmt)
110110

111111
# scalar
112112
user = await db.scalar(select(User).where(User.id == 1))
@@ -123,7 +123,7 @@ async def fast_execute():
123123
await db.delete(user)
124124

125125
# run_sync
126-
await db.run_sync(Base.metadata.create_all)
126+
await db.run_sync(Base.metadata.create_all, is_session=False)
127127

128128
# session_maker
129129
async with db.session_maker() as session:
@@ -154,19 +154,19 @@ from sqlalchemy import insert, select, update, delete
154154
def fast_execute():
155155
# update
156156
stmt = update(User).where(User.id == 1).values({'username': 'new_user'})
157-
result = db.execute(stmt, commit=True)
157+
result = db.execute(stmt)
158158

159159
# select
160160
stmt = select(User).where(User.id == 1)
161161
user = db.execute(stmt, on_close_pre=lambda r: r.scalar())
162162

163163
# insert
164164
stmt = insert(User).values({'username': 'User-6', 'password': 'password-6'})
165-
result = db.execute(stmt, commit=True)
165+
result = db.execute(stmt)
166166

167167
# delete
168168
stmt = delete(User).where(User.id == 6)
169-
result = db.execute(stmt, commit=True)
169+
result = db.execute(stmt)
170170

171171
# scalar
172172
user = db.scalar(select(User).where(User.id == 1))
@@ -183,7 +183,7 @@ def fast_execute():
183183
db.delete(user)
184184

185185
# run_sync
186-
db.run_sync(Base.metadata.create_all)
186+
db.run_sync(Base.metadata.create_all, is_session=False)
187187

188188
# session_maker
189189
with db.session_maker() as session:
@@ -212,19 +212,19 @@ from sqlalchemy_database import AsyncDatabase, Database
212212
async def fast_execute(db: Union[AsyncDatabase, Database]):
213213
# update
214214
stmt = update(User).where(User.id == 1).values({'username': 'new_user'})
215-
result = await db.async_execute(stmt, commit=True)
215+
result = await db.async_execute(stmt)
216216

217217
# select
218218
stmt = select(User).where(User.id == 1)
219219
user = await db.async_execute(stmt, on_close_pre=lambda r: r.scalar())
220220

221221
# insert
222222
stmt = insert(User).values({'username': 'User-6', 'password': 'password-6'})
223-
result = await db.async_execute(stmt, commit=True)
223+
result = await db.async_execute(stmt)
224224

225225
# delete
226226
stmt = delete(User).where(User.id == 6)
227-
result = await db.async_execute(stmt, commit=True)
227+
result = await db.async_execute(stmt)
228228

229229
# scalar
230230
user = await db.async_scalar(select(User).where(User.id == 1))
@@ -241,7 +241,7 @@ async def fast_execute(db: Union[AsyncDatabase, Database]):
241241
await db.async_delete(user)
242242

243243
# run_sync
244-
await db.async_run_sync(Base.metadata.create_all)
244+
await db.async_run_sync(Base.metadata.create_all, is_session=False)
245245

246246
```
247247

sqlalchemy_database/_abc_async_database.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
2323
*,
2424
execution_options: Optional[_ExecuteOptions] = None,
2525
bind_arguments: Optional[Mapping[str, Any]] = None,
26-
commit: bool = False,
26+
commit: bool = True,
2727
on_close_pre: Callable[[Result], _T] = None,
28+
is_session: bool = True,
2829
**kw: Any,
2930
) -> Union[Result, _T]: ...
3031

@@ -67,6 +68,6 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
6768
*args: _P.args,
6869
commit: bool = True,
6970
on_close_pre: Callable[[_T], _R] = None,
70-
is_session: bool = False,
71+
is_session: bool = True,
7172
**kwargs: _P.kwargs
7273
) -> Union[_T,_R]: ...

sqlalchemy_database/database.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine
55
from sqlalchemy.future import Engine, create_engine
66
from sqlalchemy.orm import sessionmaker, Session
7-
from sqlalchemy.sql import Executable
8-
from typing_extensions import Concatenate, ParamSpec
9-
7+
from sqlalchemy.sql import Executable, Select
108
from sqlalchemy_database._abc_async_database import AbcAsyncDatabase
9+
from typing_extensions import Concatenate, ParamSpec
1110

1211
_P = ParamSpec("_P")
1312
_T = TypeVar("_T")
@@ -43,16 +42,23 @@ async def execute(
4342
*,
4443
execution_options: Optional[_ExecuteOptions] = None,
4544
bind_arguments: Optional[Mapping[str, Any]] = None,
46-
commit: bool = False,
45+
commit: bool = True,
4746
on_close_pre: Callable[[Result], _T] = None,
47+
is_session: bool = True,
4848
**kw: Any,
4949
) -> Union[Result, _T]:
50-
async with self.session_maker() as session:
51-
result = await session.execute(statement, params, execution_options, bind_arguments, **kw)
50+
if is_session:
51+
maker = self.session_maker
52+
kw['bind_arguments'] = bind_arguments
53+
else:
54+
maker = self.engine.connect
55+
async with maker() as conn:
56+
57+
result = await conn.execute(statement, params, execution_options, **kw)
5258
if on_close_pre:
5359
result = on_close_pre(result)
54-
if commit:
55-
await session.commit()
60+
if commit and not isinstance(statement, Select):
61+
await conn.commit()
5662
return result
5763

5864
async def scalar(
@@ -122,9 +128,9 @@ async def run_sync(
122128
*args: _P.args,
123129
commit: bool = True,
124130
on_close_pre: Callable[[_T], _R] = None,
125-
is_session: bool = False,
131+
is_session: bool = True,
126132
**kwargs: _P.kwargs
127-
) -> Union[_T,_R]:
133+
) -> Union[_T, _R]:
128134
maker = self.session_maker if is_session else self.engine.connect
129135
async with maker() as conn:
130136
result = await conn.run_sync(fn, *args, **kwargs)
@@ -160,16 +166,22 @@ def execute(
160166
*,
161167
execution_options: Optional[_ExecuteOptions] = None,
162168
bind_arguments: Optional[Mapping[str, Any]] = None,
163-
commit: bool = False,
169+
commit: bool = True,
164170
on_close_pre: Callable[[Result], _T] = None,
171+
is_session: bool = True,
165172
**kw: Any,
166173
) -> Union[Result, _T]:
167-
with self.session_maker() as session:
168-
result = session.execute(statement, params, execution_options, bind_arguments, **kw)
174+
if is_session:
175+
maker = self.session_maker
176+
kw['bind_arguments'] = bind_arguments
177+
else:
178+
maker = self.engine.connect
179+
with maker() as conn:
180+
result = conn.execute(statement, params, execution_options, **kw)
169181
if on_close_pre:
170182
result = on_close_pre(result)
171-
if commit:
172-
session.commit()
183+
if commit and not isinstance(statement, Select):
184+
conn.commit()
173185
return result
174186

175187
def scalar(
@@ -240,7 +252,7 @@ def run_sync(
240252
*args: _P.args,
241253
commit: bool = True,
242254
on_close_pre: Callable[[_T], _R] = None,
243-
is_session: bool = False,
255+
is_session: bool = True,
244256
**kwargs: _P.kwargs
245257
) -> Union[_T, _R]:
246258
maker = self.session_maker if is_session else self.engine.connect

tests/test_AbcAsyncDatabase.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ def db(request) -> Union[Database, AsyncDatabase]:
1616

1717
@pytest.fixture()
1818
async def prepare_database(db) -> AsyncGenerator[None, None]:
19-
await db.async_run_sync(Base.metadata.create_all)
19+
await db.async_run_sync(Base.metadata.create_all, is_session=False)
2020
yield
21-
await db.async_run_sync(Base.metadata.drop_all)
21+
await db.async_run_sync(Base.metadata.drop_all, is_session=False)
2222

2323

2424
@pytest.fixture
@@ -30,14 +30,14 @@ async def fake_users(db, prepare_database) -> List[dict]:
3030
"create_time": datetime.datetime.strptime(f"2022-01-0{i} 00:00:00", "%Y-%m-%d %H:%M:%S")
3131
} for i in range(1, 6)
3232
]
33-
await db.async_execute(insert(User).values(data), commit=True)
33+
await db.async_execute(insert(User).values(data))
3434
return data
3535

3636

3737
async def test_async_execute(db, fake_users):
3838
# update
3939
stmt = update(User).where(User.id == 1).values({'username': 'new_user'})
40-
result = await db.async_execute(stmt, commit=True)
40+
result = await db.async_execute(stmt)
4141
assert result.rowcount == 1
4242
# select
4343
user = await db.async_execute(select(User).where(User.id == 1), on_close_pre=lambda r: r.scalar())
@@ -48,14 +48,20 @@ async def test_async_execute(db, fake_users):
4848
'username': 'User-6',
4949
'password': 'password_6'
5050
})
51-
result = await db.async_execute(stmt, commit=True)
51+
result = await db.async_execute(stmt)
5252
assert result.rowcount == 1
5353
# delete
5454
stmt = delete(User).where(User.id == 6)
55-
result = await db.async_execute(stmt, commit=True)
55+
result = await db.async_execute(stmt)
5656
assert result.rowcount == 1
5757

5858

59+
async def test_async_execute_connection(db, fake_users):
60+
# Select
61+
user = await db.async_execute(select(User).where(User.id == 1), is_session=False, on_close_pre=lambda r: r.one())
62+
assert user.id == 1
63+
64+
5965
async def test_async_scalar(db, fake_users):
6066
user = await db.async_scalar(select(User).where(User.id == 1))
6167
assert user.id == 1
@@ -91,13 +97,13 @@ def delete_user(session: Session, instance: User):
9197

9298
user = await db.async_get(User, 1)
9399
assert user.id == 1
94-
await db.async_run_sync(delete_user, user, is_session=True)
100+
await db.async_run_sync(delete_user, user)
95101
user = await db.async_get(User, 1)
96102
assert user is None
97103

98104
# test on_close_pre
99105
def get_user(session: Session, user_id: int):
100106
return session.get(User, user_id)
101107

102-
user_id = await db.async_run_sync(get_user, 2, is_session=True, on_close_pre=lambda r: r.id)
108+
user_id = await db.async_run_sync(get_user, 2, on_close_pre=lambda r: r.id)
103109
assert user_id == 2

0 commit comments

Comments
 (0)