diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 689f86b..3632e16 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -113,6 +113,33 @@ async def channel_check( return ctype, owner_id +async def _max_role_position(guild_id, member_id) -> Optional[int]: + return await app.db.fetchval( + """ + SELECT MAX(role.position) + FROM member_roles + JOIN roles ON roles.id = member_roles.role_id + WHERE member_roles.guild_id = $1 AND + member_roles.member_id = $2 + """, + guild_id, + member_id, + ) + + +async def _validate_target_member(guild_id: int, user_id: int, target_member_id: int): + # there is no internal function to fetch full role objects + # (likely because it would be too expensive to do it here), + # so instead do a raw sql query. + target_max_position = await _max_role_position(guild_id, target_member_id) + user_max_position = await _max_role_position(guild_id, user_id) + + assert target_max_position is not None + assert user_max_position is not None + + return user_max_position > target_max_position + + async def guild_perm_check( user_id, guild_id, @@ -131,34 +158,7 @@ async def guild_perm_check( # check on the target's max(role.position), if its equal or greater, # raise MissingPermissions if hasperm and target_member_id: - - # there is no internal function to fetch full role objects - # (likely because it would be too expensive to do it here), - # so instead do a raw sql query. - target_max_position = await app.db.fetchval( - """ - SELECT MAX(role.position) - FROM member_roles - JOIN roles ON roles.id = member_roles.role_id - WHERE member_roles.member_id = $1 - """, - target_member_id, - ) - user_max_position = await app.db.fetchval( - """ - SELECT MAX(role.position) - FROM member_roles - JOIN roles ON roles.id = member_roles.role_id - WHERE member_roles.member_id = $1 - """, - user_id, - ) - - assert target_max_position is not None - assert user_max_position is not None - - if user_max_position <= target_max_position: - hasperm = False + hasperm = await _validate_target_member(guild_id, user_id, target_member_id) if not hasperm and raise_err: raise MissingPermissions("Missing permissions.")