Skip to content

Commit a403bcf

Browse files
Goradiimmzeynalli
andauthored
Support MappedAsDataclass (#857)
* tests: to upcoming changes now fails * fix: insert dataclass support * Refactored tests and cleaned _queries.py --------- Co-authored-by: Miradil Zeynalli <miradil.zeynalli@gmail.com>
1 parent a54c013 commit a403bcf

4 files changed

Lines changed: 189 additions & 2 deletions

File tree

sqladmin/_queries.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import dataclasses
34
from typing import TYPE_CHECKING, Any
45

56
import anyio
@@ -191,8 +192,20 @@ async def _delete_async(self, pk: str, request: Request) -> None:
191192
await session.commit()
192193
await self.model_view.after_model_delete(obj, request)
193194

195+
def _get_model_object(self, data: dict[str, Any]) -> Any:
196+
if dataclasses.is_dataclass(self.model_view.model):
197+
init_fields = {
198+
f.name for f in dataclasses.fields(self.model_view.model) if f.init
199+
}
200+
data = {k: v for k, v in data.items() if k in init_fields}
201+
202+
else:
203+
data = {}
204+
205+
return self.model_view.model(**data)
206+
194207
def _insert_sync(self, data: dict[str, Any], request: Request) -> Any:
195-
obj = self.model_view.model()
208+
obj = self._get_model_object(data)
196209

197210
with self.model_view.session_maker(expire_on_commit=False) as session:
198211
anyio.from_thread.run(
@@ -207,7 +220,7 @@ def _insert_sync(self, data: dict[str, Any], request: Request) -> Any:
207220
return obj
208221

209222
async def _insert_async(self, data: dict[str, Any], request: Request) -> Any:
210-
obj = self.model_view.model()
223+
obj = self._get_model_object(data)
211224

212225
async with self.model_view.session_maker(expire_on_commit=False) as session:
213226
await self.model_view.on_model_change(data, obj, True, request)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import pytest
2+
from sqlalchemy import __version__ as __sa_version__
3+
4+
if __sa_version__.startswith("1."):
5+
pytest.skip("SQLAlchemy 1.4 does not support dataclasses", allow_module_level=True)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from typing import AsyncGenerator
2+
3+
import pytest
4+
from httpx import ASGITransport, AsyncClient
5+
from sqlalchemy import Integer, String, func, select
6+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
7+
from sqlalchemy.orm import (
8+
DeclarativeBase,
9+
Mapped,
10+
MappedAsDataclass,
11+
mapped_column,
12+
)
13+
from starlette.applications import Starlette
14+
15+
from sqladmin import Admin
16+
from sqladmin.models import ModelView
17+
from tests.common import async_engine as engine
18+
19+
pytestmark = pytest.mark.anyio
20+
session_maker = async_sessionmaker(
21+
bind=engine,
22+
class_=AsyncSession,
23+
expire_on_commit=False,
24+
)
25+
26+
27+
class Base(MappedAsDataclass, DeclarativeBase):
28+
pass
29+
30+
31+
class User(Base):
32+
__tablename__ = "users"
33+
34+
id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False)
35+
name: Mapped[str] = mapped_column(String(length=16), init=True)
36+
email: Mapped[str] = mapped_column(String, unique=True, nullable=True, init=False)
37+
38+
39+
class UserAdmin(ModelView, model=User):
40+
column_list = ["name", "email"]
41+
column_labels = {"name": "Name", "email": "Email"}
42+
43+
44+
app = Starlette()
45+
admin = Admin(app=app, engine=engine)
46+
admin.add_view(UserAdmin)
47+
48+
49+
@pytest.fixture(autouse=True)
50+
async def prepare_database() -> AsyncGenerator[None, None]:
51+
async with engine.begin() as conn:
52+
await conn.run_sync(Base.metadata.create_all)
53+
yield
54+
async with engine.begin() as conn:
55+
await conn.run_sync(Base.metadata.drop_all)
56+
57+
await engine.dispose()
58+
59+
60+
@pytest.fixture
61+
async def client() -> AsyncGenerator[AsyncClient, None]:
62+
transport = ASGITransport(app=app)
63+
async with AsyncClient(transport=transport, base_url="http://testserver") as c:
64+
yield c
65+
66+
67+
async def test_async_create_dataclass(client: AsyncClient) -> None:
68+
await client.post("/admin/user/create", data={"name": "foo", "email": "bar"})
69+
stmt = select(func.count(User.id))
70+
async with session_maker() as s:
71+
result = await s.execute(stmt)
72+
assert result.scalar_one() == 1
73+
74+
75+
async def test_update_dataclass(client: AsyncClient) -> None:
76+
async with session_maker() as session:
77+
user = User(name="John Doe")
78+
session.add(user)
79+
await session.commit()
80+
81+
await client.post("/admin/user/edit/1", data={"name": "foo", "email": "bar"})
82+
83+
stmt = select(User)
84+
async with session_maker() as s:
85+
result = await s.execute(stmt)
86+
user = result.scalar_one()
87+
assert user.name == "foo"
88+
assert user.email == "bar"
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from typing import Generator
2+
3+
import pytest
4+
from sqlalchemy import (
5+
Integer,
6+
String,
7+
func,
8+
select,
9+
)
10+
from sqlalchemy.orm import (
11+
DeclarativeBase,
12+
Mapped,
13+
MappedAsDataclass,
14+
mapped_column,
15+
sessionmaker,
16+
)
17+
from starlette.applications import Starlette
18+
from starlette.testclient import TestClient
19+
20+
from sqladmin import Admin, ModelView
21+
from tests.common import sync_engine as engine
22+
23+
session_maker = sessionmaker(bind=engine)
24+
25+
26+
class Base(MappedAsDataclass, DeclarativeBase):
27+
pass
28+
29+
30+
class User(Base):
31+
__tablename__ = "users"
32+
33+
id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False)
34+
name: Mapped[str] = mapped_column(String(length=16), init=True)
35+
email: Mapped[str] = mapped_column(String, unique=True, nullable=True, init=False)
36+
37+
38+
class UserAdmin(ModelView, model=User):
39+
column_list = ["name", "email"]
40+
column_labels = {"name": "Name", "email": "Email"}
41+
42+
43+
app = Starlette()
44+
admin = Admin(app=app, engine=engine)
45+
admin.add_model_view(UserAdmin)
46+
47+
48+
@pytest.fixture(autouse=True)
49+
def prepare_database() -> Generator[None, None, None]:
50+
Base.metadata.create_all(engine)
51+
yield
52+
Base.metadata.drop_all(engine)
53+
54+
55+
@pytest.fixture
56+
def client() -> Generator[TestClient, None, None]:
57+
with TestClient(app=app, base_url="http://testserver") as c:
58+
yield c
59+
60+
61+
def test_sync_create_dataclass(client: TestClient) -> None:
62+
client.post("/admin/user/create", data={"name": "foo", "email": "bar"})
63+
stmt = select(func.count(User.id))
64+
with session_maker() as s:
65+
result = s.execute(stmt)
66+
assert result.scalar_one() == 1
67+
68+
69+
def test_update_dataclass(client: TestClient) -> None:
70+
with session_maker() as session:
71+
user = User(name="John Doe")
72+
session.add(user)
73+
session.commit()
74+
75+
client.post("/admin/user/edit/1", data={"name": "foo", "email": "bar"})
76+
77+
stmt = select(User)
78+
with session_maker() as s:
79+
user = s.execute(stmt).scalar_one()
80+
assert user.name == "foo"
81+
assert user.email == "bar"

0 commit comments

Comments
 (0)