11from contextvars import ContextVar
2- from threading import Lock
32from typing import (
43 Any ,
54 AsyncGenerator ,
@@ -69,8 +68,6 @@ def __init__(self, engine: AsyncEngine, **session_options):
6968 await session.commit()
7069 ```
7170 """
72- self ._session_lock = Lock ()
73- self ._session_enter_count = 0
7471 self ._session_context_var : ContextVar [Optional [AsyncSession ]] = ContextVar ("_session_context_var" , default = None )
7572
7673 @property
@@ -82,37 +79,22 @@ def session(self) -> Optional[AsyncSession]:
8279 Example:
8380 ```Python
8481 app = FastAPI()
85- app.add_middleware(BaseHTTPMiddleware, db.asgi_dispatch)
82+ app.add_middleware(BaseHTTPMiddleware, dispatch= db.asgi_dispatch)
8683
8784 @app.get('/get_user')
8885 async def get_user(id:int):
8986 return await db.session.get(User,id)
9087 ```
9188 In ordinary methods, session will return None. You can get it through:
9289 ```Python
93- async with db:
90+ async with db() :
9491 db.session.get(User,id)
9592 ```
9693 """
97- return self ._session_context_var .get () if self . _session_enter_count > 0 else None
94+ return self ._session_context_var .get ()
9895
99- async def __aenter__ (self ):
100- with self ._session_lock :
101- session = self .session
102- if session is None :
103- session = self .session_maker ()
104- self ._session_context_var_token = self ._session_context_var .set (session )
105- self ._session_enter_count += 1
106- return session
107-
108- async def __aexit__ (self , exc_type , exc_value , traceback ):
109- with self ._session_lock :
110- self ._session_enter_count -= 1
111- if self ._session_enter_count <= 0 :
112- session = self ._session_context_var .get ()
113- if session is not None :
114- await session .close ()
115- self ._session_context_var .reset (self ._session_context_var_token )
96+ def __call__ (self ):
97+ return AsyncSessionContextVarManager (self )
11698
11799 @classmethod
118100 def create (cls , url : str , * , session_options : Mapping [str , Any ] = None , ** kwargs ) -> "AsyncDatabase" :
@@ -428,8 +410,6 @@ def get_user(session:Session,id:int):
428410 APIs which will be properly adapted to the greenlet context.
429411 """
430412 need_close = False
431- if executor is None and is_session :
432- executor = self .session
433413 if executor is None or not isinstance (executor , (AsyncSession , AsyncConnection )):
434414 if is_session :
435415 executor = self .session
@@ -456,38 +436,15 @@ def __init__(self, engine: Engine, **session_options):
456436 self .engine : Engine = engine
457437 session_options .setdefault ("class_" , Session )
458438 self .session_maker : Callable [..., Session ] = sessionmaker (self .engine , ** session_options )
459- self ._session_lock = Lock ()
460- self ._session_enter_count = 0
461439 self ._session_context_var : ContextVar [Optional [Session ]] = ContextVar ("_session_context_var" , default = None )
462440
463441 @property
464442 def session (self ) -> Optional [Session ]:
465443 """Return an instance of Session local to the current context."""
466- return self ._session_context_var .get () if self . _session_enter_count > 0 else None
444+ return self ._session_context_var .get ()
467445
468- def __enter__ (self ):
469- with self ._session_lock :
470- session = self .session
471- if session is None :
472- session = self .session_maker ()
473- self ._session_context_var_token = self ._session_context_var .set (session )
474- self ._session_enter_count += 1
475- return session
476-
477- def __exit__ (self , exc_type , exc_value , traceback ):
478- with self ._session_lock :
479- self ._session_enter_count -= 1
480- if self ._session_enter_count <= 0 :
481- session = self ._session_context_var .get ()
482- if session is not None :
483- session .close ()
484- self ._session_context_var .reset (self ._session_context_var_token )
485-
486- async def __aenter__ (self ):
487- return self .__enter__ ()
488-
489- async def __aexit__ (self , exc_type , exc_value , traceback ):
490- return self .__exit__ (exc_type , exc_value , traceback )
446+ def __call__ (self ):
447+ return SessionContextVarManager (self )
491448
492449 @classmethod
493450 def create (cls , url : str , * , session_options : Optional [Mapping [str , Any ]] = None , ** kwargs ) -> "Database" :
@@ -514,8 +471,6 @@ def execute(
514471 ** kw : Any ,
515472 ) -> Union [Result , _T ]:
516473 need_close = False
517- if executor is None and is_session :
518- executor = self .session
519474 if executor is None or not isinstance (executor , (Session , Connection )):
520475 need_close = True
521476 if is_session :
@@ -688,3 +643,45 @@ async def __aenter__(self):
688643 async def __aexit__ (self , exc_type , exc_val , exc_tb ):
689644 if self .need_close :
690645 await self .executor .close ()
646+
647+
648+ class AsyncSessionContextVarManager :
649+ def __init__ (self , db : AsyncDatabase ):
650+ self .db = db
651+ self .token = None
652+
653+ async def __aenter__ (self ):
654+ session = self .db .session_maker ()
655+ self .token = self .db ._session_context_var .set (session )
656+ return session
657+
658+ async def __aexit__ (self , exc_type , exc_value , traceback ):
659+ session = self .db ._session_context_var .get ()
660+ if exc_type is not None :
661+ await session .rollback ()
662+ await session .close ()
663+ self .db ._session_context_var .reset (self .token )
664+
665+
666+ class SessionContextVarManager :
667+ def __init__ (self , db : Database ):
668+ self .db = db
669+ self .token = None
670+
671+ def __enter__ (self ):
672+ session = self .db .session_maker ()
673+ self .token = self .db ._session_context_var .set (session )
674+ return session
675+
676+ def __exit__ (self , exc_type , exc_value , traceback ):
677+ session = self .db ._session_context_var .get ()
678+ if exc_type is not None :
679+ session .rollback ()
680+ session .close ()
681+ self .db ._session_context_var .reset (self .token )
682+
683+ async def __aenter__ (self ):
684+ return self .__enter__ ()
685+
686+ async def __aexit__ (self , exc_type , exc_value , traceback ):
687+ self .__exit__ (exc_type , exc_value , traceback )
0 commit comments