|
| 1 | +from typing import AsyncGenerator |
| 2 | + |
| 3 | +import pytest |
| 4 | +from httpx import ASGITransport, AsyncClient |
| 5 | +from sqlalchemy import Integer, String, func, select |
| 6 | +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker |
| 7 | +from sqlalchemy.orm import ( |
| 8 | + DeclarativeBase, |
| 9 | + Mapped, |
| 10 | + MappedAsDataclass, |
| 11 | + mapped_column, |
| 12 | +) |
| 13 | +from starlette.applications import Starlette |
| 14 | + |
| 15 | +from sqladmin import Admin |
| 16 | +from sqladmin.models import ModelView |
| 17 | +from tests.common import async_engine as engine |
| 18 | + |
| 19 | +pytestmark = pytest.mark.anyio |
| 20 | +session_maker = async_sessionmaker( |
| 21 | + bind=engine, |
| 22 | + class_=AsyncSession, |
| 23 | + expire_on_commit=False, |
| 24 | +) |
| 25 | + |
| 26 | + |
| 27 | +class Base(MappedAsDataclass, DeclarativeBase): |
| 28 | + pass |
| 29 | + |
| 30 | + |
| 31 | +class User(Base): |
| 32 | + __tablename__ = "users" |
| 33 | + |
| 34 | + id: Mapped[int] = mapped_column(Integer, primary_key=True, init=False) |
| 35 | + name: Mapped[str] = mapped_column(String(length=16), init=True) |
| 36 | + email: Mapped[str] = mapped_column(String, unique=True, nullable=True, init=False) |
| 37 | + |
| 38 | + |
| 39 | +class UserAdmin(ModelView, model=User): |
| 40 | + column_list = ["name", "email"] |
| 41 | + column_labels = {"name": "Name", "email": "Email"} |
| 42 | + |
| 43 | + |
| 44 | +app = Starlette() |
| 45 | +admin = Admin(app=app, engine=engine) |
| 46 | +admin.add_view(UserAdmin) |
| 47 | + |
| 48 | + |
| 49 | +@pytest.fixture(autouse=True) |
| 50 | +async def prepare_database() -> AsyncGenerator[None, None]: |
| 51 | + async with engine.begin() as conn: |
| 52 | + await conn.run_sync(Base.metadata.create_all) |
| 53 | + yield |
| 54 | + async with engine.begin() as conn: |
| 55 | + await conn.run_sync(Base.metadata.drop_all) |
| 56 | + |
| 57 | + await engine.dispose() |
| 58 | + |
| 59 | + |
| 60 | +@pytest.fixture |
| 61 | +async def client() -> AsyncGenerator[AsyncClient, None]: |
| 62 | + transport = ASGITransport(app=app) |
| 63 | + async with AsyncClient(transport=transport, base_url="http://testserver") as c: |
| 64 | + yield c |
| 65 | + |
| 66 | + |
| 67 | +async def test_async_create_dataclass(client: AsyncClient) -> None: |
| 68 | + await client.post("/admin/user/create", data={"name": "foo", "email": "bar"}) |
| 69 | + stmt = select(func.count(User.id)) |
| 70 | + async with session_maker() as s: |
| 71 | + result = await s.execute(stmt) |
| 72 | + assert result.scalar_one() == 1 |
| 73 | + |
| 74 | + |
| 75 | +async def test_update_dataclass(client: AsyncClient) -> None: |
| 76 | + async with session_maker() as session: |
| 77 | + user = User(name="John Doe") |
| 78 | + session.add(user) |
| 79 | + await session.commit() |
| 80 | + |
| 81 | + await client.post("/admin/user/edit/1", data={"name": "foo", "email": "bar"}) |
| 82 | + |
| 83 | + stmt = select(User) |
| 84 | + async with session_maker() as s: |
| 85 | + result = await s.execute(stmt) |
| 86 | + user = result.scalar_one() |
| 87 | + assert user.name == "foo" |
| 88 | + assert user.email == "bar" |
0 commit comments