Skip to content

Commit 101a193

Browse files
committed
fix: Fix Session ContextVar Manager
1 parent 2b6738c commit 101a193

6 files changed

Lines changed: 63 additions & 62 deletions

File tree

sqlalchemy_database/_abc_async_database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ async def asgi_dispatch(self, request, call_next):
3737
```Python
3838
app = FastAPI()
3939
db = Database.create("sqlite:///test.db")
40-
app.add_middleware(BaseHTTPMiddleware, db.asgi_dispatch)
40+
app.add_middleware(BaseHTTPMiddleware, dispatch=db.asgi_dispatch)
4141
```
4242
"""
43-
async with self:
43+
async with self.__call__():
4444
response = await call_next(request)
4545
return response

sqlalchemy_database/_abc_async_database.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ from sqlalchemy.ext.asyncio import AsyncConnection
1616
from sqlalchemy.sql import Executable
1717
from typing_extensions import Concatenate, ParamSpec
1818

19+
from sqlalchemy_database.database import AsyncSessionContextVarManager
20+
1921
try:
2022
from sqlmodel import Session
2123
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -92,3 +94,5 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
9294
**kwargs: _P.kwargs,
9395
) -> Union[_T, _R]: ...
9496
async def asgi_dispatch(self, request, call_next): ...
97+
def __call__(self) -> AsyncSessionContextVarManager:
98+
pass

sqlalchemy_database/database.py

Lines changed: 50 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from contextvars import ContextVar
2-
from threading import Lock
32
from typing import (
43
Any,
54
AsyncGenerator,
@@ -69,8 +68,6 @@ def __init__(self, engine: AsyncEngine, **session_options):
6968
await session.commit()
7069
```
7170
"""
72-
self._session_lock = Lock()
73-
self._session_enter_count = 0
7471
self._session_context_var: ContextVar[Optional[AsyncSession]] = ContextVar("_session_context_var", default=None)
7572

7673
@property
@@ -82,37 +79,22 @@ def session(self) -> Optional[AsyncSession]:
8279
Example:
8380
```Python
8481
app = FastAPI()
85-
app.add_middleware(BaseHTTPMiddleware, db.asgi_dispatch)
82+
app.add_middleware(BaseHTTPMiddleware, dispatch=db.asgi_dispatch)
8683
8784
@app.get('/get_user')
8885
async def get_user(id:int):
8986
return await db.session.get(User,id)
9087
```
9188
In ordinary methods, session will return None. You can get it through:
9289
```Python
93-
async with db:
90+
async with db():
9491
db.session.get(User,id)
9592
```
9693
"""
97-
return self._session_context_var.get() if self._session_enter_count > 0 else None
94+
return self._session_context_var.get()
9895

99-
async def __aenter__(self):
100-
with self._session_lock:
101-
session = self.session
102-
if session is None:
103-
session = self.session_maker()
104-
self._session_context_var_token = self._session_context_var.set(session)
105-
self._session_enter_count += 1
106-
return session
107-
108-
async def __aexit__(self, exc_type, exc_value, traceback):
109-
with self._session_lock:
110-
self._session_enter_count -= 1
111-
if self._session_enter_count <= 0:
112-
session = self._session_context_var.get()
113-
if session is not None:
114-
await session.close()
115-
self._session_context_var.reset(self._session_context_var_token)
96+
def __call__(self):
97+
return AsyncSessionContextVarManager(self)
11698

11799
@classmethod
118100
def create(cls, url: str, *, session_options: Mapping[str, Any] = None, **kwargs) -> "AsyncDatabase":
@@ -428,8 +410,6 @@ def get_user(session:Session,id:int):
428410
APIs which will be properly adapted to the greenlet context.
429411
"""
430412
need_close = False
431-
if executor is None and is_session:
432-
executor = self.session
433413
if executor is None or not isinstance(executor, (AsyncSession, AsyncConnection)):
434414
if is_session:
435415
executor = self.session
@@ -456,38 +436,15 @@ def __init__(self, engine: Engine, **session_options):
456436
self.engine: Engine = engine
457437
session_options.setdefault("class_", Session)
458438
self.session_maker: Callable[..., Session] = sessionmaker(self.engine, **session_options)
459-
self._session_lock = Lock()
460-
self._session_enter_count = 0
461439
self._session_context_var: ContextVar[Optional[Session]] = ContextVar("_session_context_var", default=None)
462440

463441
@property
464442
def session(self) -> Optional[Session]:
465443
"""Return an instance of Session local to the current context."""
466-
return self._session_context_var.get() if self._session_enter_count > 0 else None
444+
return self._session_context_var.get()
467445

468-
def __enter__(self):
469-
with self._session_lock:
470-
session = self.session
471-
if session is None:
472-
session = self.session_maker()
473-
self._session_context_var_token = self._session_context_var.set(session)
474-
self._session_enter_count += 1
475-
return session
476-
477-
def __exit__(self, exc_type, exc_value, traceback):
478-
with self._session_lock:
479-
self._session_enter_count -= 1
480-
if self._session_enter_count <= 0:
481-
session = self._session_context_var.get()
482-
if session is not None:
483-
session.close()
484-
self._session_context_var.reset(self._session_context_var_token)
485-
486-
async def __aenter__(self):
487-
return self.__enter__()
488-
489-
async def __aexit__(self, exc_type, exc_value, traceback):
490-
return self.__exit__(exc_type, exc_value, traceback)
446+
def __call__(self):
447+
return SessionContextVarManager(self)
491448

492449
@classmethod
493450
def create(cls, url: str, *, session_options: Optional[Mapping[str, Any]] = None, **kwargs) -> "Database":
@@ -514,8 +471,6 @@ def execute(
514471
**kw: Any,
515472
) -> Union[Result, _T]:
516473
need_close = False
517-
if executor is None and is_session:
518-
executor = self.session
519474
if executor is None or not isinstance(executor, (Session, Connection)):
520475
need_close = True
521476
if is_session:
@@ -688,3 +643,45 @@ async def __aenter__(self):
688643
async def __aexit__(self, exc_type, exc_val, exc_tb):
689644
if self.need_close:
690645
await self.executor.close()
646+
647+
648+
class AsyncSessionContextVarManager:
649+
def __init__(self, db: AsyncDatabase):
650+
self.db = db
651+
self.token = None
652+
653+
async def __aenter__(self):
654+
session = self.db.session_maker()
655+
self.token = self.db._session_context_var.set(session)
656+
return session
657+
658+
async def __aexit__(self, exc_type, exc_value, traceback):
659+
session = self.db._session_context_var.get()
660+
if exc_type is not None:
661+
await session.rollback()
662+
await session.close()
663+
self.db._session_context_var.reset(self.token)
664+
665+
666+
class SessionContextVarManager:
667+
def __init__(self, db: Database):
668+
self.db = db
669+
self.token = None
670+
671+
def __enter__(self):
672+
session = self.db.session_maker()
673+
self.token = self.db._session_context_var.set(session)
674+
return session
675+
676+
def __exit__(self, exc_type, exc_value, traceback):
677+
session = self.db._session_context_var.get()
678+
if exc_type is not None:
679+
session.rollback()
680+
session.close()
681+
self.db._session_context_var.reset(self.token)
682+
683+
async def __aenter__(self):
684+
return self.__enter__()
685+
686+
async def __aexit__(self, exc_type, exc_value, traceback):
687+
self.__exit__(exc_type, exc_value, traceback)

tests/test_AbcAsyncDatabase.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def get_user(session: Session, user_id: int):
127127

128128

129129
async def test_async_session_context_var(db, fake_users):
130-
async with db:
130+
async with db():
131131
# test db function
132132
user = await db.async_get(User, 1)
133133
assert user.id == 1

tests/test_AsyncDatabase.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,16 +175,16 @@ async def test_sqlmodel_session(fake_users):
175175

176176

177177
async def test_async_session_context_var(fake_users):
178-
async with db as session:
178+
async with db() as session:
179179
# test enter return session
180180
user = await session.get(User, 1)
181181
assert user.id == 1
182182

183183
# test nested session
184-
async with db as session2:
184+
async with db() as session2:
185185
user = await session2.get(User, 1)
186186
assert user.id == 1
187-
assert session is session2
187+
assert session is not session2
188188
# test db.session
189189
user = await db.session.get(User, 1)
190190
assert user.id == 1

tests/test_Database.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,16 @@ def test_sqlmodel_session(fake_users):
173173

174174

175175
def test_session_context_var(fake_users):
176-
with db as session:
176+
with db() as session:
177177
# test enter return session
178178
user = session.get(User, 1)
179179
assert user.id == 1
180180

181181
# test nested session
182-
with db as session2:
182+
with db() as session2:
183183
user = session2.get(User, 1)
184184
assert user.id == 1
185-
assert session is session2
185+
assert session is not session2
186186

187187
# test db.session
188188
user = db.session.get(User, 1)

0 commit comments

Comments
 (0)