Skip to content

Commit cc5e32f

Browse files
committed
feat: try importing Session from SQLModel
1 parent 89a7641 commit cc5e32f

5 files changed

Lines changed: 34 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ test = [
4949
"pytest-cov",
5050
"pytest-asyncio>=0.17",
5151
"httpx",
52+
"sqlmodel",
5253
]
5354
docs = [
5455
"mkdocs-material>=8.3.8",

sqlalchemy_database/_abc_async_database.pyi

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@ import abc
22
from typing import Any, Optional, Mapping, Union, Sequence, Type, List, TypeVar, Callable
33

44
from sqlalchemy.engine import Result, Connection
5-
from sqlalchemy.ext.asyncio import AsyncSession, AsyncConnection
6-
from sqlalchemy.orm import Session
5+
from sqlalchemy.ext.asyncio import AsyncConnection
76
from sqlalchemy.sql import Executable
87
from typing_extensions import ParamSpec, Concatenate
98

9+
try:
10+
from sqlmodel import Session
11+
from sqlmodel.ext.asyncio.session import AsyncSession
12+
except ImportError:
13+
from sqlalchemy.orm import Session
14+
from sqlalchemy.ext.asyncio import AsyncSession
15+
1016
_P = ParamSpec("_P")
1117
_T = TypeVar("_T")
1218
_R = TypeVar("_R")

sqlalchemy_database/database.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
from typing import Generator, Any, AsyncGenerator, Optional, Mapping, Union, Sequence, Type, List, Callable, TypeVar
22

33
from sqlalchemy.engine import Result, Connection
4-
from sqlalchemy.ext.asyncio import AsyncSession, AsyncEngine, create_async_engine, AsyncConnection
4+
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine, AsyncConnection
55
from sqlalchemy.future import Engine, create_engine
6-
from sqlalchemy.orm import sessionmaker, Session
6+
from sqlalchemy.orm import sessionmaker
77
from sqlalchemy.sql import Executable, Select
88
from typing_extensions import Concatenate, ParamSpec
99

10+
try:
11+
from sqlmodel import Session
12+
from sqlmodel.ext.asyncio.session import AsyncSession
13+
except ImportError:
14+
from sqlalchemy.orm import Session
15+
from sqlalchemy.ext.asyncio import AsyncSession
16+
1017
from sqlalchemy_database._abc_async_database import AbcAsyncDatabase
1118

1219
_P = ParamSpec("_P")
@@ -368,6 +375,7 @@ class Database(AbcAsyncDatabase):
368375
def __init__(self, engine: Engine, **session_options):
369376
super().__init__()
370377
self.engine: Engine = engine
378+
session_options.setdefault('class_', Session)
371379
self.session_maker: Callable[..., Session] = sessionmaker(self.engine, **session_options)
372380

373381
@classmethod

tests/test_AsyncDatabase.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,11 @@ async def test_executor(fake_users):
154154
users = await db.scalars_all(select(User), session = session)
155155
for user in users:
156156
assert user.group is None if user.group_id is None else user.group
157+
158+
async def test_sqlmodel_session(fake_users):
159+
from sqlmodel import select
160+
161+
async with db.session_maker() as session:
162+
result = await session.exec(select(User))
163+
user = result.first()
164+
assert user.id == 1

tests/test_Database.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,10 @@ def test_executor(fake_users):
153153
users = db.scalars_all(select(User), session = session)
154154
for user in users:
155155
assert user.group is None if user.group_id is None else user.group
156+
157+
def test_sqlmodel_session(fake_users):
158+
from sqlmodel import select
159+
160+
with db.session_maker() as session:
161+
user = session.exec(select(User)).first()
162+
assert user.id == 1

0 commit comments

Comments
 (0)