Skip to content

Commit c2a157c

Browse files
committed
test: Add pressure test
1 parent fdf7285 commit c2a157c

2 files changed

Lines changed: 49 additions & 16 deletions

File tree

tests/test_AsyncDatabase.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import asyncio
12
import datetime
3+
from asyncio import AbstractEventLoop
24
from contextlib import asynccontextmanager
35
from typing import AsyncGenerator, List
46

@@ -174,7 +176,13 @@ async def test_sqlmodel_session(fake_users):
174176
assert user.id == 1
175177

176178

177-
async def test_async_session_context_var(fake_users):
179+
@pytest.fixture()
180+
def lock(event_loop: AbstractEventLoop):
181+
return asyncio.Lock(loop=event_loop)
182+
183+
184+
async def test_async_session_context_var(fake_users, lock, i=1):
185+
178186
async with db() as session:
179187
# test enter return session
180188
user = await session.get(User, 1)
@@ -194,14 +202,13 @@ async def test_async_session_context_var(fake_users):
194202
# test db function
195203
user = await db.get(User, 1)
196204
assert user.id == 1
197-
group = Group(name="group1")
205+
group = Group(name=f"group{i}")
198206
await db.save(group, refresh=True)
199-
assert group.id == 1
200-
user.group_id = group.id
201-
202-
await db.save(user, group, refresh=True)
203-
assert user.group_id == group.id
204-
assert user.group.name == "group1" # type: ignore
207+
async with lock: # test async concurrency safe, because the same user is operated here, so a lock is needed
208+
user.group_id = group.id
209+
await db.save(user, group, refresh=True)
210+
assert user.group_id == group.id
211+
assert user.group.name == f"group{i}" # type: ignore
205212

206213
user2 = await db.get(User, 2, options=[selectinload(User.group)])
207214
assert user2.group is None
@@ -214,3 +221,14 @@ async def test_async_session_context_var(fake_users):
214221
assert user.group is None if user.group_id is None else user.group
215222

216223
assert db.session is None
224+
return i
225+
226+
227+
def test_asyncio_groups(fake_users, event_loop: AbstractEventLoop, lock):
228+
task_count = 40
229+
tasks = [asyncio.ensure_future(test_async_session_context_var(fake_users, lock, i=i)) for i in range(task_count)]
230+
event_loop.run_until_complete(asyncio.wait(tasks))
231+
assert len(tasks) == task_count
232+
for task in tasks:
233+
assert task.result() is not None
234+
assert task.exception() is None

tests/test_Database.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import datetime
2+
import threading
3+
from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor, wait
24
from contextlib import contextmanager
35
from typing import Generator, List
46

@@ -172,7 +174,10 @@ def test_sqlmodel_session(fake_users):
172174
assert user.id == 1
173175

174176

175-
def test_session_context_var(fake_users):
177+
lock = threading.Lock()
178+
179+
180+
def test_session_context_var(fake_users, i=1):
176181
with db() as session:
177182
# test enter return session
178183
user = session.get(User, 1)
@@ -192,14 +197,14 @@ def test_session_context_var(fake_users):
192197
# test db function
193198
user = db.get(User, 1)
194199
assert user.id == 1
195-
group = Group(name="group1")
200+
group = Group(name=f"group{i}")
196201
db.save(group, refresh=True)
197-
assert group.id == 1
198-
user.group_id = group.id
199202

200-
db.save(user, refresh=True)
201-
assert user.group_id == group.id
202-
assert user.group.name == "group1" # type: ignore
203+
with lock: # test thread safe, because the same user is operated here, so a lock is needed
204+
user.group_id = group.id
205+
db.save(user, refresh=True)
206+
assert user.group_id == group.id
207+
assert user.group.name == f"group{i}" # type: ignore
203208

204209
user2 = db.get(User, 2)
205210
assert user2.group is None
@@ -210,5 +215,15 @@ def test_session_context_var(fake_users):
210215
users = db.scalars_all(select(User))
211216
for user in users:
212217
assert user.group is None if user.group_id is None else user.group
213-
214218
assert db.session is None
219+
return i
220+
221+
222+
def test_ThreadPoolExecutor(fake_users):
223+
task_count = 40
224+
pool = ThreadPoolExecutor(max_workers=20) # 创建线程池,设置最大线程数
225+
all_task = [pool.submit(test_session_context_var, fake_users, k) for k in range(task_count)] # 投递任务
226+
# print(all_task)
227+
done, fail = wait(all_task, return_when=ALL_COMPLETED) # 等待线程运行完毕
228+
results = {task.result() for task in done}
229+
assert len(results) == task_count

0 commit comments

Comments
 (0)