Skip to content

Commit cf441a4

Browse files
committed
feat: Add refresh method
1 parent c2a157c commit cf441a4

3 files changed

Lines changed: 36 additions & 1 deletion

File tree

sqlalchemy_database/_abc_async_database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ async def to_thread(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs)
2222

2323
class AbcAsyncDatabase(metaclass=abc.ABCMeta): # noqa: B024
2424
def __init__(self) -> None:
25-
for func_name in ["execute", "scalar", "scalars_all", "get", "delete", "save", "run_sync"]:
25+
for func_name in ["execute", "scalar", "scalars_all", "get", "delete", "save", "run_sync", "refresh"]:
2626
func = getattr(self, func_name)
2727
if not asyncio.iscoroutinefunction(func):
2828
func = functools.partial(to_thread, func)

sqlalchemy_database/_abc_async_database.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
9393
executor: Union[Session, Connection, AsyncSession, AsyncConnection, None] = None,
9494
**kwargs: _P.kwargs,
9595
) -> Union[_T, _R]: ...
96+
async def async_refresh(
97+
self, instance, attribute_names=None, with_for_update=None, session: Optional[AsyncSession] = None
98+
): ...
9699
async def asgi_dispatch(self, request, call_next): ...
97100
def __call__(self) -> AsyncSessionContextVarManager:
98101
pass

sqlalchemy_database/database.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,28 @@ def get_user(session:Session,id:int):
427427
await executor.commit()
428428
return result
429429

430+
async def refresh(self, instance, attribute_names=None, with_for_update=None, session: Optional[AsyncSession] = None):
431+
"""
432+
Refresh the attributes of the given instance from the database.
433+
Args:
434+
instance: The instance to be refreshed.
435+
attribute_names: Optional list of attribute names to refresh.
436+
with_for_update: optional boolean ``True`` indicating FOR UPDATE should be used,
437+
or may be a dictionary containing flags to
438+
indicate a more specific set of FOR UPDATE flags for the SELECT;
439+
flags should match the parameters of :meth:`_query.Query.with_for_update`.
440+
Supersedes the :paramref:`.Session.refresh.lockmode` parameter.
441+
session: If not specified, an `AsyncSession` is created.
442+
"""
443+
need_close = False
444+
if session is None or not isinstance(session, AsyncSession):
445+
session = self.session
446+
if session is None:
447+
need_close = True
448+
session = self.session_maker()
449+
async with ExecutorContextManager(session, need_close=need_close) as session:
450+
await session.refresh(instance, attribute_names, with_for_update)
451+
430452

431453
class Database(AbcAsyncDatabase):
432454
"""`sqlalchemy` synchronous database client"""
@@ -622,6 +644,16 @@ def run_sync(
622644
executor.commit()
623645
return result
624646

647+
def refresh(self, instance, attribute_names=None, with_for_update=None, session: Optional[Session] = None):
648+
need_close = False
649+
if session is None or not isinstance(session, Session):
650+
session = self.session
651+
if session is None:
652+
need_close = True
653+
session = self.session_maker()
654+
with ExecutorContextManager(session, need_close=need_close) as session:
655+
session.refresh(instance, attribute_names=attribute_names, with_for_update=with_for_update)
656+
625657

626658
class ExecutorContextManager:
627659
"""Actuator context manager, optionally closing the executor"""

0 commit comments

Comments
 (0)