Skip to content

Commit 07dfaf3

Browse files
committed
perf: Add the scoped attribute to determine whether the session attribute is scoped.
1 parent bd228a5 commit 07dfaf3

3 files changed

Lines changed: 17 additions & 3 deletions

File tree

sqlalchemy_database/database.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,13 @@ async def get_user(id:int):
107107
"""
108108
return self.scoped_session()
109109

110+
@property
111+
def scoped(self) -> bool:
112+
"""Whether the current context has a session. If False, the session is the default global session,
113+
and the transaction needs to be manually submitted.
114+
"""
115+
return bool(self._session_scope.get())
116+
110117
def __call__(self, scope: Any = None):
111118
return AsyncSessionContextVarManager(self, scope=scope)
112119

@@ -140,7 +147,7 @@ async def session_generator(self) -> AsyncGenerator[AsyncSession, Any]:
140147
return await session.get(User,id)
141148
```
142149
"""
143-
if self._session_scope.get():
150+
if self.scoped:
144151
"""If the current context has a session, return it."""
145152
yield self.session
146153
else:
@@ -210,6 +217,10 @@ def __init__(self, engine: Engine, commit_on_exit: bool = True, **session_option
210217
def session(self) -> Session:
211218
return self.scoped_session()
212219

220+
@property
221+
def scoped(self) -> bool:
222+
return bool(self._session_scope.get())
223+
213224
def __call__(self, scope: Any = None):
214225
return SessionContextVarManager(self, scope=scope)
215226

@@ -223,7 +234,7 @@ def create(
223234
return cls(engine, **session_options)
224235

225236
def session_generator(self) -> Generator[Session, Any, None]:
226-
if self._session_scope.get():
237+
if self.scoped:
227238
"""If the current context has a session, return it."""
228239
yield self.session
229240
else:

tests/test_AsyncDatabase.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,9 @@ def lock(event_loop: AbstractEventLoop):
9898

9999
async def test_async_session_context_var(lock, i=1):
100100
global_session = async_db.session # Default global session
101-
101+
assert not async_db.scoped
102102
async with async_db() as session: # Enter a new session
103+
assert async_db.scoped
103104
user = await session.get(User, 1)
104105
assert user.id == 1
105106
assert session is async_db.session

tests/test_Database.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ def test_sqlmodel_session():
8484

8585
def test_session_context_var(i=1):
8686
global_session = sync_db.session # Default global session
87+
assert not sync_db.scoped
8788
with sync_db() as session: # Enter a new session
89+
assert sync_db.scoped
8890
user = session.get(User, 1)
8991
assert user.id == 1
9092
assert session is sync_db.session

0 commit comments

Comments
 (0)