|
5 | 5 | import logging |
6 | 6 | from random import randint |
7 | 7 | import re |
8 | | -from typing import Any, cast, Callable, Dict, List, Optional, Set |
| 8 | +from typing import Any, cast, Callable, Dict, List, Optional, Set, Tuple |
9 | 9 | import uuid |
10 | 10 |
|
11 | 11 | from passlib.context import CryptContext |
@@ -864,6 +864,63 @@ def get_token(session: Session, token: uuid.UUID) -> Token: |
864 | 864 | return token_object |
865 | 865 |
|
866 | 866 |
|
| 867 | +def get_current_user_extended( |
| 868 | + session: Session, token: uuid.UUID |
| 869 | +) -> Tuple[bool, data.UserExtendedResponse]: |
| 870 | + """ |
| 871 | + Extend user data with groups for authentication middleware. |
| 872 | + """ |
| 873 | + query = ( |
| 874 | + session.query(Token.active, User, Group, GroupUser) |
| 875 | + .join(User, User.id == Token.user_id) |
| 876 | + .join(GroupUser, GroupUser.user_id == User.id) |
| 877 | + .join(Group, Group.id == GroupUser.group_id) |
| 878 | + .filter(Token.id == token) |
| 879 | + .filter(GroupUser.user_id == User.id) |
| 880 | + .group_by(Token.active, User, Group, GroupUser) |
| 881 | + ) |
| 882 | + objects = query.all() |
| 883 | + if len(objects) == 0: |
| 884 | + raise TokenNotFound(f"Token not found with ID: {token}") |
| 885 | + |
| 886 | + active_token = objects[0][0] |
| 887 | + user = objects[0][1] |
| 888 | + groups = [] |
| 889 | + for object in objects: |
| 890 | + if object[3].user_id != user.id: |
| 891 | + logger.error( |
| 892 | + f"Unexpected group id: {object[2].id} fetched for user with id: {user.id}" |
| 893 | + ) |
| 894 | + raise Exception("Unexpected group in list") |
| 895 | + groups.append( |
| 896 | + data.GroupUserResponse( |
| 897 | + group_id=object[2].id, |
| 898 | + user_id=object[3].user_id, |
| 899 | + user_type=object[3].user_type, |
| 900 | + autogenerated=object[2].autogenerated, |
| 901 | + group_name=object[2].name, |
| 902 | + parent=object[2].parent, |
| 903 | + ) |
| 904 | + ) |
| 905 | + |
| 906 | + user_extended = data.UserExtendedResponse( |
| 907 | + user_id=user.id, |
| 908 | + username=user.username, |
| 909 | + first_name=user.first_name, |
| 910 | + last_name=user.last_name, |
| 911 | + email=user.email, |
| 912 | + normalized_email=user.normalized_email, |
| 913 | + verified=user.verified, |
| 914 | + created_at=user.created_at, |
| 915 | + updated_at=user.updated_at, |
| 916 | + autogenerated=user.autogenerated, |
| 917 | + application_id=user.application_id, |
| 918 | + groups=groups, |
| 919 | + ) |
| 920 | + |
| 921 | + return active_token, user_extended |
| 922 | + |
| 923 | + |
867 | 924 | def update_token( |
868 | 925 | session: Session, |
869 | 926 | token: uuid.UUID, |
|
0 commit comments