1+ import asyncio
12import datetime
3+ from asyncio import AbstractEventLoop
24from contextlib import asynccontextmanager
35from typing import AsyncGenerator , List
46
@@ -174,7 +176,13 @@ async def test_sqlmodel_session(fake_users):
174176 assert user .id == 1
175177
176178
177- async def test_async_session_context_var (fake_users ):
179+ @pytest .fixture ()
180+ def lock (event_loop : AbstractEventLoop ):
181+ return asyncio .Lock (loop = event_loop )
182+
183+
184+ async def test_async_session_context_var (fake_users , lock , i = 1 ):
185+
178186 async with db () as session :
179187 # test enter return session
180188 user = await session .get (User , 1 )
@@ -194,14 +202,13 @@ async def test_async_session_context_var(fake_users):
194202 # test db function
195203 user = await db .get (User , 1 )
196204 assert user .id == 1
197- group = Group (name = "group1 " )
205+ group = Group (name = f"group { i } " )
198206 await db .save (group , refresh = True )
199- assert group .id == 1
200- user .group_id = group .id
201-
202- await db .save (user , group , refresh = True )
203- assert user .group_id == group .id
204- assert user .group .name == "group1" # type: ignore
207+ async with lock : # test async concurrency safe, because the same user is operated here, so a lock is needed
208+ user .group_id = group .id
209+ await db .save (user , group , refresh = True )
210+ assert user .group_id == group .id
211+ assert user .group .name == f"group{ i } " # type: ignore
205212
206213 user2 = await db .get (User , 2 , options = [selectinload (User .group )])
207214 assert user2 .group is None
@@ -214,3 +221,14 @@ async def test_async_session_context_var(fake_users):
214221 assert user .group is None if user .group_id is None else user .group
215222
216223 assert db .session is None
224+ return i
225+
226+
227+ def test_asyncio_groups (fake_users , event_loop : AbstractEventLoop , lock ):
228+ task_count = 40
229+ tasks = [asyncio .ensure_future (test_async_session_context_var (fake_users , lock , i = i )) for i in range (task_count )]
230+ event_loop .run_until_complete (asyncio .wait (tasks ))
231+ assert len (tasks ) == task_count
232+ for task in tasks :
233+ assert task .result () is not None
234+ assert task .exception () is None
0 commit comments