Skip to content

Commit 775ea86

Browse files
committed
feat: Add the refresh option to the db.save shortcut function.
1 parent 947308e commit 775ea86

4 files changed

Lines changed: 16 additions & 3 deletions

File tree

sqlalchemy_database/database.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ async def save(self, *instances: Any, refresh: bool = False) -> None:
271271
async with self.session_maker() as session:
272272
async with session.begin():
273273
session.add_all(instances)
274+
if refresh:
275+
[await session.refresh(instance) for instance in instances]
274276

275277
async def run_sync(
276278
self,
@@ -435,6 +437,8 @@ def save(self, *instances: Any, refresh: bool = False) -> None:
435437
with self.session_maker() as session:
436438
with session.begin():
437439
session.add_all(instances)
440+
if refresh:
441+
[session.refresh(instance) for instance in instances]
438442

439443
def run_sync(
440444
self,

tests/test_AbcAsyncDatabase.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import pytest
55
from sqlalchemy import insert, select, update, delete
66
from sqlalchemy.orm import Session
7-
87
from sqlalchemy_database import AsyncDatabase, Database
98
from tests.conftest import async_db, sync_db, Base, User
109

@@ -104,6 +103,10 @@ async def test_async_save(db, fake_users):
104103
await db.async_save(user2)
105104
u = await db.async_scalar(select(User).where(User.username == 'new_user2'))
106105
assert u.username == 'new_user2'
106+
# test refresh
107+
user3 = User(username='new_user3')
108+
await db.async_save(user3, refresh=True)
109+
assert user3.id
107110

108111

109112
async def test_async_run_sync(db, fake_users):

tests/test_AsyncDatabase.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,10 @@ async def test_save(fake_users):
111111
await db.save(user2)
112112
u = await db.scalar(select(User).where(User.username == 'new_user2'))
113113
assert u.username == 'new_user2'
114-
114+
# test refresh
115+
user3 = User(username='new_user3')
116+
await db.save(user3, refresh=True)
117+
assert user3.id
115118

116119
async def test_run_sync(fake_users):
117120
def delete_user(session: Session, instance: User):

tests/test_Database.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import pytest
66
from sqlalchemy import insert, select, update, delete
77
from sqlalchemy.orm import Session
8-
98
from tests.conftest import sync_db as db, Base, User
109

1110

@@ -111,6 +110,10 @@ def test_save(fake_users):
111110
db.save(user2)
112111
u = db.scalar(select(User).where(User.username == 'new_user2'))
113112
assert u.username == 'new_user2'
113+
# test refresh
114+
user3 = User(username='new_user3')
115+
db.save(user3, refresh=True)
116+
assert user3.id
114117

115118

116119
def test_run_sync(fake_users):

0 commit comments

Comments
 (0)