Skip to content

Commit c07ff1c

Browse files
committed
feat: add db.save shortcut function
1 parent 111d945 commit c07ff1c

8 files changed

Lines changed: 92 additions & 11 deletions

File tree

README.md

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ from sqlalchemy import insert, select, update, delete
9494
async def fast_execute():
9595
# update
9696
stmt = update(User).where(User.id == 1).values({'username': 'new_user'})
97-
result = await db.execute(stmt, commit=True)
97+
result = await db.execute(stmt)
9898

9999
# select
100100
stmt = select(User).where(User.id == 1)
@@ -121,7 +121,11 @@ async def fast_execute():
121121
# delete
122122
user = User(id=1, name='test')
123123
await db.delete(user)
124-
124+
125+
# save(insert or update)
126+
user = User(name='new_user')
127+
await db.save(user)
128+
125129
# run_sync
126130
await db.run_sync(Base.metadata.create_all, is_session=False)
127131

@@ -154,7 +158,7 @@ from sqlalchemy import insert, select, update, delete
154158
def fast_execute():
155159
# update
156160
stmt = update(User).where(User.id == 1).values({'username': 'new_user'})
157-
result = db.execute(stmt, commit=True)
161+
result = db.execute(stmt)
158162

159163
# select
160164
stmt = select(User).where(User.id == 1)
@@ -181,7 +185,11 @@ def fast_execute():
181185
# delete
182186
user = User(id=1, name='test')
183187
db.delete(user)
184-
188+
189+
# save(insert or update)
190+
user = User(name='new_user')
191+
db.save(user)
192+
185193
# run_sync
186194
db.run_sync(Base.metadata.create_all, is_session=False)
187195

@@ -242,7 +250,11 @@ async def fast_execute(db: Union[AsyncDatabase, Database]):
242250
# delete
243251
user = User(id=1, name='test')
244252
await db.async_delete(user)
245-
253+
254+
# save(insert or update)
255+
user = User(name='new_user')
256+
await db.async_save(user)
257+
246258
# run_sync
247259
await db.async_run_sync(Base.metadata.create_all, is_session=False)
248260

README.zh.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ async def fast_execute():
122122
user = User(id=1, name='test')
123123
await db.delete(user)
124124

125+
# save(insert or update)
126+
user = User(name='new_user')
127+
await db.save(user)
128+
125129
# run_sync
126130
await db.run_sync(Base.metadata.create_all, is_session=False)
127131

@@ -182,6 +186,10 @@ def fast_execute():
182186
user = User(id=1, name='test')
183187
db.delete(user)
184188

189+
# save(insert or update)
190+
user = User(name='new_user')
191+
db.save(user)
192+
185193
# run_sync
186194
db.run_sync(Base.metadata.create_all, is_session=False)
187195

@@ -239,7 +247,11 @@ async def fast_execute(db: Union[AsyncDatabase, Database]):
239247
# delete
240248
user = User(id=1, name='test')
241249
await db.async_delete(user)
242-
250+
251+
# save(insert or update)
252+
user = User(name='new_user')
253+
await db.async_save(user)
254+
243255
# run_sync
244256
await db.async_run_sync(Base.metadata.create_all, is_session=False)
245257

sqlalchemy_database/_abc_async_database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ async def to_thread(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs)
2323
class AbcAsyncDatabase(metaclass=abc.ABCMeta):
2424

2525
def __init__(self) -> None:
26-
for func_name in ['execute', 'scalar', 'scalars_all', 'get', 'delete', 'run_sync']:
26+
for func_name in ['execute', 'scalar', 'scalars_all', 'get', 'delete', 'save', 'run_sync']:
2727
func = getattr(self, func_name)
2828
if not asyncio.iscoroutinefunction(func):
2929
func = functools.partial(to_thread, func)

sqlalchemy_database/_abc_async_database.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
6262

6363
async def async_delete(self, instance: Any) -> None: ...
6464

65+
async def async_save(self, *instances: Any) -> None: ...
66+
6567
async def async_run_sync(
6668
self,
6769
fn: Callable[[Concatenate[Union[Session, Connection], _P]], _T],
@@ -70,4 +72,4 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
7072
on_close_pre: Callable[[_T], _R] = None,
7173
is_session: bool = True,
7274
**kwargs: _P.kwargs
73-
) -> Union[_T,_R]: ...
75+
) -> Union[_T, _R]: ...

sqlalchemy_database/database.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
from sqlalchemy.future import Engine, create_engine
66
from sqlalchemy.orm import sessionmaker, Session
77
from sqlalchemy.sql import Executable, Select
8-
from sqlalchemy_database._abc_async_database import AbcAsyncDatabase
98
from typing_extensions import Concatenate, ParamSpec
109

10+
from sqlalchemy_database._abc_async_database import AbcAsyncDatabase
11+
1112
_P = ParamSpec("_P")
1213
_T = TypeVar("_T")
1314
_R = TypeVar("_R")
@@ -53,8 +54,7 @@ async def execute(
5354
else:
5455
maker = self.engine.connect
5556
async with maker() as conn:
56-
57-
result = await conn.execute(statement, params, execution_options, **kw)
57+
result = await conn.execute(statement, params, execution_options, **kw) # type:ignore
5858
if on_close_pre:
5959
result = on_close_pre(result)
6060
if commit and not isinstance(statement, Select):
@@ -122,6 +122,11 @@ async def delete(self, instance: Any) -> None:
122122
async with session.begin():
123123
await session.delete(instance)
124124

125+
async def save(self, *instances: Any) -> None:
126+
async with self.session_maker() as session:
127+
async with session.begin():
128+
session.add_all(instances)
129+
125130
async def run_sync(
126131
self,
127132
fn: Callable[[Concatenate[Union[Session, Connection], _P]], _T],
@@ -246,6 +251,11 @@ def delete(self, instance: Any) -> None:
246251
with session.begin():
247252
session.delete(instance)
248253

254+
def save(self, *instances: Any) -> None:
255+
with self.session_maker() as session:
256+
with session.begin():
257+
session.add_all(instances)
258+
249259
def run_sync(
250260
self,
251261
fn: Callable[[Concatenate[Union[Session, Connection], _P]], _T],

tests/test_AbcAsyncDatabase.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,21 @@ async def test_async_delete(db, fake_users):
9191
assert user is None
9292

9393

94+
async def test_async_save(db, fake_users):
95+
# test update
96+
user = await db.async_get(User, 1)
97+
assert user.id == 1
98+
user.username = 'new_user'
99+
await db.async_save(user)
100+
user = await db.async_get(User, 1)
101+
assert user.username == 'new_user'
102+
# test insert
103+
user2 = User(username='new_user2')
104+
await db.async_save(user2)
105+
u = await db.async_scalar(select(User).where(User.username == 'new_user2'))
106+
assert u.username == 'new_user2'
107+
108+
94109
async def test_async_run_sync(db, fake_users):
95110
def delete_user(session: Session, instance: User):
96111
session.delete(instance)

tests/test_AsyncDatabase.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,21 @@ async def test_delete(fake_users):
9898
assert user is None
9999

100100

101+
async def test_save(fake_users):
102+
# test update
103+
user = await db.get(User, 1)
104+
assert user.id == 1
105+
user.username = 'new_user'
106+
await db.save(user)
107+
user = await db.get(User, 1)
108+
assert user.username == 'new_user'
109+
# test insert
110+
user2 = User(username='new_user2')
111+
await db.save(user2)
112+
u = await db.scalar(select(User).where(User.username == 'new_user2'))
113+
assert u.username == 'new_user2'
114+
115+
101116
async def test_run_sync(fake_users):
102117
def delete_user(session: Session, instance: User):
103118
session.delete(instance)

tests/test_Database.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,21 @@ def test_delete(fake_users):
9898
assert user is None
9999

100100

101+
def test_save(fake_users):
102+
# test update
103+
user = db.get(User, 1)
104+
assert user.id == 1
105+
user.username = 'new_user'
106+
db.save(user)
107+
user = db.get(User, 1)
108+
assert user.username == 'new_user'
109+
# test insert
110+
user2 = User(username='new_user2')
111+
db.save(user2)
112+
u = db.scalar(select(User).where(User.username == 'new_user2'))
113+
assert u.username == 'new_user2'
114+
115+
101116
def test_run_sync(fake_users):
102117
def delete_user(session: Session, instance: User):
103118
session.delete(instance)

0 commit comments

Comments
 (0)