Skip to content

Commit 4a22504

Browse files
committed
fix: Replace asgi_dispatch with asgi_middleware.
1 parent ff801c4 commit 4a22504

3 files changed

Lines changed: 41 additions & 14 deletions

File tree

sqlalchemy_database/_abc_async_database.py

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
22
import asyncio
33
import functools
4+
import warnings
45
from typing import Callable, Dict, TypeVar, Union
56

67
from sqlalchemy.ext.asyncio import AsyncEngine
@@ -85,17 +86,15 @@ def __init__(self, engine: Union[Engine, AsyncEngine], *args, **kwargs) -> None:
8586
setattr(self, f"async_{func_name}", func)
8687

8788
async def asgi_dispatch(self, request, call_next):
88-
"""Middleware for ASGI applications, such as: Starlette, FastAPI, Quart, Sanic, Hug, Responder, etc.
89-
Bind a SQLAlchemy session connection to the incoming HTTP request session context,
90-
you can access the session object through `self.session`.
91-
The instance shortcut method will also try to use this `session` object by default.
92-
Example:
93-
```Python
94-
app = FastAPI()
95-
db = Database.create("sqlite:///test.db")
96-
app.add_middleware(BaseHTTPMiddleware, dispatch=db.asgi_dispatch)
97-
```
9889
"""
90+
This method has been deprecated and is not recommended. Please use the `asgi_middleware` method instead.
91+
Reference: https://www.starlette.io/middleware/#limitations
92+
"""
93+
# 打印警告信息
94+
warnings.warn(
95+
"This method has been deprecated and is not recommended. Please use the `asgi_middleware` method instead.",
96+
DeprecationWarning,
97+
)
9998
if request.scope.get(f"__sqlalchemy_database__:{id(self)}", False):
10099
return await call_next(request)
101100
# bind session to request
@@ -112,6 +111,32 @@ def attach_middleware(self, app):
112111
db.attach_middlewares(app)
113112
```
114113
"""
115-
from starlette.middleware.base import BaseHTTPMiddleware
114+
app.add_middleware(self.asgi_middleware)
115+
116+
@property
117+
def asgi_middleware(self):
118+
"""Middleware for ASGI applications, such as: Starlette, FastAPI, Quart, Sanic, Hug, Responder, etc.
119+
Bind a SQLAlchemy session connection to the incoming HTTP request session context,
120+
you can access the session object through `self.session`.
121+
The instance shortcut method will also try to use this `session` object by default.
122+
Example:
123+
```Python
124+
app = FastAPI()
125+
db = Database.create("sqlite:///test.db")
126+
app.add_middleware(db.asgi_middleware)
127+
```
128+
"""
129+
130+
def asgi_decorator(app):
131+
@functools.wraps(app)
132+
async def wrapped_app(scope, receive, send):
133+
if scope.get(f"__sqlalchemy_database__:{id(self)}", False):
134+
return await app(scope, receive, send)
135+
# bind session to request
136+
async with self.__call__(scope=id(scope)):
137+
scope[f"__sqlalchemy_database__:{id(self)}"] = self
138+
await app(scope, receive, send)
139+
140+
return wrapped_app
116141

117-
app.add_middleware(BaseHTTPMiddleware, dispatch=self.asgi_dispatch)
142+
return asgi_decorator

sqlalchemy_database/_abc_async_database.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ from typing import (
66
Mapping,
77
Optional,
88
Sequence,
9+
Tuple,
910
Type,
1011
TypeVar,
1112
Union,
@@ -46,7 +47,8 @@ class AbcAsyncDatabase(metaclass=abc.ABCMeta):
4647
is_session: bool = True,
4748
**kwargs: _P.kwargs,
4849
) -> _T: ...
49-
async def asgi_dispatch(self, request, call_next): ...
50+
def asgi_middleware(self, app: Any) -> Callable[[Any], Tuple[Mapping[str, Any], Any, Any]]: ...
51+
def attach_middleware(self, app: Any) -> None: ...
5052
def __call__(self, scope: Any = None) -> AsyncSessionContextVarManager:
5153
pass
5254
async def async_close(self) -> None: ...

tests/test_fastapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async def test_async_db_in_fastapi():
8080
app = FastAPI()
8181
sub_app = FastAPI()
8282
app.mount("/sub", sub_app)
83-
app.add_middleware(BaseHTTPMiddleware, dispatch=async_db.asgi_dispatch)
83+
app.add_middleware(async_db.asgi_middleware)
8484
client = TestClient(app)
8585

8686
@app.get("/users")

0 commit comments

Comments
 (0)