From 6c1a73233b95ce743921f96f2b8124543054d0e2 Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 30 Aug 2021 22:13:20 -0300 Subject: [PATCH] checks: add target_member_id to guild_perm_check --- litecord/blueprints/checks.py | 56 ++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 52ab5bc..689f86b 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -113,18 +113,66 @@ async def channel_check( return ctype, owner_id -async def guild_perm_check(user_id, guild_id, permission: str): - """Check guild permissions for a user.""" +async def guild_perm_check( + user_id, + guild_id, + permission: str, + target_member_id: Optional[int] = None, + raise_err: bool = True, +) -> bool: + """Check guild permissions for a user. + + Accepts optional target argument for actions that are done TO another member + in the guild.""" base_perms = await base_permissions(user_id, guild_id) hasperm = getattr(base_perms.bits, permission) - if not hasperm: + # if we have the PERM and there's a target member involved, + # check on the target's max(role.position), if its equal or greater, + # raise MissingPermissions + if hasperm and target_member_id: + + # there is no internal function to fetch full role objects + # (likely because it would be too expensive to do it here), + # so instead do a raw sql query. + target_max_position = await app.db.fetchval( + """ + SELECT MAX(role.position) + FROM member_roles + JOIN roles ON roles.id = member_roles.role_id + WHERE member_roles.member_id = $1 + """, + target_member_id, + ) + user_max_position = await app.db.fetchval( + """ + SELECT MAX(role.position) + FROM member_roles + JOIN roles ON roles.id = member_roles.role_id + WHERE member_roles.member_id = $1 + """, + user_id, + ) + + assert target_max_position is not None + assert user_max_position is not None + + if user_max_position <= target_max_position: + hasperm = False + + if not hasperm and raise_err: raise MissingPermissions("Missing permissions.") return bool(hasperm) -async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True): +# TODO add target semantics +async def channel_perm_check( + user_id, + channel_id, + permission: str, + raise_err=True, +): """Check channel permissions for a user.""" base_perms = await get_permissions(user_id, channel_id) hasperm = getattr(base_perms.bits, permission)