Skip to content

Commit ad6febd

Browse files
committed
Auth middleware to avoid multiple queries
1 parent 5813ad3 commit ad6febd

4 files changed

Lines changed: 92 additions & 1 deletion

File tree

brood/actions.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import logging
66
from random import randint
77
import re
8-
from typing import Any, cast, Callable, Dict, List, Optional, Set
8+
from typing import Any, cast, Callable, Dict, List, Optional, Set, Tuple
99
import uuid
1010

1111
from passlib.context import CryptContext
@@ -864,6 +864,63 @@ def get_token(session: Session, token: uuid.UUID) -> Token:
864864
return token_object
865865

866866

867+
def get_current_user_extended(
868+
session: Session, token: uuid.UUID
869+
) -> Tuple[bool, data.UserExtendedResponse]:
870+
"""
871+
Extend user data with groups for authentication middleware.
872+
"""
873+
query = (
874+
session.query(Token.active, User, Group, GroupUser)
875+
.join(User, User.id == Token.user_id)
876+
.join(GroupUser, GroupUser.user_id == User.id)
877+
.join(Group, Group.id == GroupUser.group_id)
878+
.filter(Token.id == token)
879+
.filter(GroupUser.user_id == User.id)
880+
.group_by(Token.active, User, Group, GroupUser)
881+
)
882+
objects = query.all()
883+
if len(objects) == 0:
884+
raise TokenNotFound(f"Token not found with ID: {token}")
885+
886+
active_token = objects[0][0]
887+
user = objects[0][1]
888+
groups = []
889+
for object in objects:
890+
if object[3].user_id != user.id:
891+
logger.error(
892+
f"Unexpected group id: {object[2].id} fetched for user with id: {user.id}"
893+
)
894+
raise Exception("Unexpected group in list")
895+
groups.append(
896+
data.GroupUserResponse(
897+
group_id=object[2].id,
898+
user_id=object[3].user_id,
899+
user_type=object[3].user_type,
900+
autogenerated=object[2].autogenerated,
901+
group_name=object[2].name,
902+
parent=object[2].parent,
903+
)
904+
)
905+
906+
user_extended = data.UserExtendedResponse(
907+
user_id=user.id,
908+
username=user.username,
909+
first_name=user.first_name,
910+
last_name=user.last_name,
911+
email=user.email,
912+
normalized_email=user.normalized_email,
913+
verified=user.verified,
914+
created_at=user.created_at,
915+
updated_at=user.updated_at,
916+
autogenerated=user.autogenerated,
917+
application_id=user.application_id,
918+
groups=groups,
919+
)
920+
921+
return active_token, user_extended
922+
923+
867924
def update_token(
868925
session: Session,
869926
token: uuid.UUID,

brood/api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
oauth2_scheme,
3131
autogenerated_user_token_check,
3232
get_current_user,
33+
current_user_extended,
3334
is_token_restricted,
3435
is_token_restricted_or_installation,
3536
get_current_user_or_installation,
@@ -461,6 +462,16 @@ async def get_user_by_id_handler(
461462
return user
462463

463464

465+
@app.get("/auth", tags=["users"], response_model=data.UserExtendedResponse)
466+
async def get_auth_handler(
467+
current_user: data.UserExtendedResponse = Depends(current_user_extended),
468+
) -> data.UserExtendedResponse:
469+
"""
470+
Authorization middleware.
471+
"""
472+
return current_user
473+
474+
464475
@app.post("/confirm", tags=["users"], response_model=data.UserResponse)
465476
async def verification_handler(
466477
token_restricted: bool = Depends(is_token_restricted),

brood/data.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,3 +273,7 @@ class ApplicationResponse(BaseModel):
273273

274274
class ApplicationsListResponse(BaseModel):
275275
applications: List[ApplicationResponse] = Field(default_factory=list)
276+
277+
278+
class UserExtendedResponse(UserResponse):
279+
groups: List[GroupUserResponse] = Field(default_factory=list)

brood/middleware.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from fastapi.security import OAuth2PasswordBearer
1010

1111
from . import actions
12+
from . import data
1213
from . import models
1314
from .external import yield_db_session_from_env
1415
from .settings import BOT_INSTALLATION_TOKEN, BOT_INSTALLATION_TOKEN_HEADER
@@ -32,6 +33,24 @@ async def get_current_user(
3233
return token_object.user
3334

3435

36+
async def current_user_extended(
37+
token: UUID = Depends(oauth2_scheme),
38+
db_session=Depends(yield_db_session_from_env),
39+
) -> data.UserExtendedResponse:
40+
try:
41+
token_active, user_extended = actions.get_current_user_extended(
42+
session=db_session, token=token
43+
)
44+
except actions.TokenNotFound:
45+
raise HTTPException(status_code=404, detail="Access token not found")
46+
except Exception:
47+
raise HTTPException(status_code=500)
48+
if not token_active:
49+
raise HTTPException(status_code=403, detail="Token has expired")
50+
51+
return user_extended
52+
53+
3554
def autogenerated_user_token_check(request: Request) -> bool:
3655
if BOT_INSTALLATION_TOKEN is None:
3756
raise ValueError("BOT_INSTALLATION_TOKEN environment variable must be set")

0 commit comments

Comments
 (0)