Skip to content

Commit ff801c4

Browse files
committed
feat: Optimize the code and add the get_engine_url method
1 parent 31c9cf7 commit ff801c4

8 files changed

Lines changed: 74 additions & 16 deletions

File tree

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ classifiers = [
3737
]
3838
dependencies = [
3939
"sqlalchemy",
40-
"sqlalchemy2-stubs>=0.0.2a29"
4140
]
4241

4342
[project.urls]

sqlalchemy_database/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
__version__ = "0.1.0"
22
__url__ = "https://github.com/amisadmin/sqlalchemy_database"
33

4-
from sqlalchemy_database._abc_async_database import AbcAsyncDatabase
5-
from sqlalchemy_database.database import AsyncDatabase, Database
4+
from sqlalchemy_database.database import AbcAsyncDatabase, AsyncDatabase, Database
65

76
__all__ = ["AsyncDatabase", "Database", "AbcAsyncDatabase"]

sqlalchemy_database/_abc_async_database.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import abc
22
import asyncio
33
import functools
4-
from typing import Callable, TypeVar
4+
from typing import Callable, Dict, TypeVar, Union
55

6+
from sqlalchemy.ext.asyncio import AsyncEngine
7+
from sqlalchemy.future import Engine
68
from sqlalchemy.orm import scoped_session
79

810
try:
@@ -23,17 +25,19 @@ async def to_thread(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs)
2325

2426

2527
class AbcAsyncDatabase(metaclass=abc.ABCMeta): # noqa: B024
26-
def __new__(cls, engine, *args, **kwargs):
28+
29+
_instances: Dict[str, "AbcAsyncDatabase"] = None
30+
31+
def __new__(cls, engine: Union[Engine, AsyncEngine], *args, **kwargs):
2732
"""Create a new instance of the database class.Each engine url corresponds to a database instance,
2833
and if it already exists, it is directly returned, otherwise a new instance is created.
2934
"""
30-
if not hasattr(cls, "_instances"):
31-
cls._instances = {}
35+
cls._instances = cls._instances or {}
3236
if engine.url not in cls._instances:
3337
cls._instances[engine.url] = super().__new__(cls)
3438
return cls._instances[engine.url]
3539

36-
def __init__(self) -> None:
40+
def __init__(self, engine: Union[Engine, AsyncEngine], *args, **kwargs) -> None:
3741
for func_name in {
3842
"run_sync",
3943
"begin",
@@ -98,3 +102,16 @@ async def asgi_dispatch(self, request, call_next):
98102
async with self.__call__(scope=id(request.scope)):
99103
request.scope[f"__sqlalchemy_database__:{id(self)}"] = self
100104
return await call_next(request)
105+
106+
def attach_middleware(self, app):
107+
"""Attach the middleware to the ASGI application.
108+
Example:
109+
```Python
110+
app = FastAPI()
111+
db = Database.create("sqlite:///test.db")
112+
db.attach_middlewares(app)
113+
```
114+
"""
115+
from starlette.middleware.base import BaseHTTPMiddleware
116+
117+
app.add_middleware(BaseHTTPMiddleware, dispatch=self.asgi_dispatch)

sqlalchemy_database/_abc_async_database.pyi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ from typing import (
1111
Union,
1212
)
1313

14-
from sqlalchemy.engine import Connection, Result
14+
from sqlalchemy.engine import Connection, Engine, Result
1515
from sqlalchemy.sql import ClauseElement, Executable
1616
from sqlmodel.engine.result import ScalarResult
1717
from typing_extensions import Concatenate, ParamSpec
@@ -22,19 +22,23 @@ try:
2222
from sqlmodel import Session
2323
from sqlmodel.ext.asyncio.session import AsyncSession
2424
except ImportError:
25-
from sqlalchemy.ext.asyncio import AsyncSession
25+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
2626
from sqlalchemy.orm import Session
2727

2828
_P = ParamSpec("_P")
2929
_T = TypeVar("_T")
3030
_R = TypeVar("_R")
3131

32+
async def to_thread(func: Callable[_P, _R], *args: _P.args, **kwargs: _P.kwargs) -> _R: ...
33+
3234
_ExecuteParams = Union[Mapping[Any, Any], Sequence[Mapping[Any, Any]]]
3335
_ExecuteOptions = Mapping[Any, Any]
3436

3537
class AbcAsyncDatabase(metaclass=abc.ABCMeta):
3638
"""`sqlalchemy` asynchronous database abstract base class, not directly instantiated"""
3739

40+
engine: Union[Engine, AsyncEngine]
41+
3842
async def async_run_sync(
3943
self,
4044
fn: Callable[[Concatenate[Union[Session, Connection], _P]], _T],

sqlalchemy_database/database.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
Union,
1212
)
1313

14-
from sqlalchemy.engine import Connection
14+
from sqlalchemy.engine import URL, Connection
1515
from sqlalchemy.ext.asyncio import (
1616
AsyncEngine,
1717
async_scoped_session,
@@ -51,7 +51,6 @@ def __init__(
5151
commit_on_exit: Whether to commit the session when the context manager or session generator exits.
5252
**session_options: The default `session` initialization parameters
5353
"""
54-
5554
self.engine: AsyncEngine = engine
5655
"""`sqlalchemy` Asynchronous Engine
5756
@@ -78,7 +77,7 @@ def __init__(
7877
f"_session_context_var_{id(self)}", default=None
7978
)
8079
self.scoped_session: async_scoped_session = async_scoped_session(self.session_maker, scopefunc=self._session_scope.get)
81-
super().__init__()
80+
super().__init__(engine)
8281

8382
@property
8483
def session(self) -> AsyncSession:
@@ -115,7 +114,7 @@ def __call__(self, scope: Any = None):
115114

116115
@classmethod
117116
def create(
118-
cls, url: str, *, commit_on_exit: bool = True, session_options: Mapping[str, Any] = None, **kwargs
117+
cls, url: Union[str, URL], *, commit_on_exit: bool = True, session_options: Mapping[str, Any] = None, **kwargs
119118
) -> "AsyncDatabase":
120119
"""
121120
Initialize the client with a database connection string
@@ -207,7 +206,7 @@ def __init__(self, engine: Engine, commit_on_exit: bool = True, **session_option
207206
self._session_scope: ContextVar[Union[str, Session, None]] = ContextVar(f"_session_context_var_{id(self)}", default=None)
208207
self.scoped_session: scoped_session = scoped_session(self.session_maker, scopefunc=self._session_scope.get)
209208
"""Returns the Session local instance for the current context or current thread."""
210-
super().__init__()
209+
super().__init__(engine)
211210

212211
@property
213212
def session(self) -> Session:
@@ -222,7 +221,7 @@ def __call__(self, scope: Any = None):
222221

223222
@classmethod
224223
def create(
225-
cls, url: str, *, commit_on_exit: bool = True, session_options: Optional[Mapping[str, Any]] = None, **kwargs
224+
cls, url: Union[str, URL], *, commit_on_exit: bool = True, session_options: Optional[Mapping[str, Any]] = None, **kwargs
226225
) -> "Database":
227226
kwargs.setdefault("future", True)
228227
engine = create_engine(url, **kwargs)

sqlalchemy_database/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Union
2+
3+
from sqlalchemy.engine import URL, make_url
4+
5+
SQLALCHEMY_DRIVER = {
6+
"sqlite": {"sync": ["pysqlite", "pysqlcipher"], "async": ["aiosqlite"]},
7+
"mysql": {"sync": ["pymysql", "mysqldb", "mysqlconnector", "cymysql", "pyodbc"], "async": ["aiomysql", "asyncmy"]},
8+
"mariadb": {"sync": ["pymysql", "mysqldb", "mysqlconnector", "cymysql", "pyodbc"], "async": ["aiomysql", "asyncmy"]},
9+
"postgresql": {"sync": ["pg8000", "pyscopg2", "psycopg", "psycopg2cffi"], "async": ["asyncpg"]},
10+
"oracle": {"sync": ["cx_oracle", "oracledb"], "async": []},
11+
"mssql": {"sync": ["pyodbc", "pymssql"], "async": []},
12+
}
13+
14+
15+
def get_engine_url(url: Union[str, URL], sync: bool = True) -> URL:
16+
url: URL = make_url(url)
17+
backend_name = url.get_backend_name()
18+
driver_name = url.get_driver_name()
19+
driver_type = "sync" if sync else "async"
20+
if driver_name in SQLALCHEMY_DRIVER.get(backend_name, {}).get(driver_type, []):
21+
return url
22+
new_driver = SQLALCHEMY_DRIVER[backend_name][driver_type][0]
23+
url = url.set(drivername=f"{backend_name}+{new_driver}")
24+
return url

tests/test_Database.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy import select
66
from sqlalchemy.orm import Session
77

8+
from sqlalchemy_database import Database
89
from tests.conftest import Group, User, sync_db
910

1011

@@ -124,3 +125,9 @@ def test_ThreadPoolExecutor():
124125
done, fail = wait(all_task, return_when=ALL_COMPLETED) # 等待线程运行完毕
125126
results = {task.result() for task in done}
126127
assert len(results) == task_count
128+
129+
130+
def test_create():
131+
sync_db1 = Database.create("sqlite:///amisadmin.db?check_same_thread=False")
132+
sync_db2 = Database.create("sqlite:///amisadmin.db?check_same_thread=False")
133+
assert sync_db2 is sync_db1

tests/test_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from sqlalchemy_database.utils import get_engine_url
2+
from tests.conftest import async_db, sync_db
3+
4+
5+
def test_get_engine_url():
6+
assert get_engine_url(sync_db.engine.url, sync=True) == sync_db.engine.url
7+
assert get_engine_url(sync_db.engine.url, sync=False) == async_db.engine.url
8+
assert get_engine_url(async_db.engine.url, sync=True) == sync_db.engine.url.set(drivername="sqlite+pysqlite")
9+
assert get_engine_url(async_db.engine.url, sync=False) == async_db.engine.url

0 commit comments

Comments
 (0)