checks: add target_member_id to guild_perm_check

This commit is contained in:
Luna 2021-08-30 22:13:20 -03:00
parent 7374091dec
commit 6c1a73233b
1 changed files with 52 additions and 4 deletions

View File

@ -113,18 +113,66 @@ async def channel_check(
return ctype, owner_id return ctype, owner_id
async def guild_perm_check(user_id, guild_id, permission: str): async def guild_perm_check(
"""Check guild permissions for a user.""" user_id,
guild_id,
permission: str,
target_member_id: Optional[int] = None,
raise_err: bool = True,
) -> bool:
"""Check guild permissions for a user.
Accepts optional target argument for actions that are done TO another member
in the guild."""
base_perms = await base_permissions(user_id, guild_id) base_perms = await base_permissions(user_id, guild_id)
hasperm = getattr(base_perms.bits, permission) hasperm = getattr(base_perms.bits, permission)
if not hasperm: # if we have the PERM and there's a target member involved,
# check on the target's max(role.position), if its equal or greater,
# raise MissingPermissions
if hasperm and target_member_id:
# there is no internal function to fetch full role objects
# (likely because it would be too expensive to do it here),
# so instead do a raw sql query.
target_max_position = await app.db.fetchval(
"""
SELECT MAX(role.position)
FROM member_roles
JOIN roles ON roles.id = member_roles.role_id
WHERE member_roles.member_id = $1
""",
target_member_id,
)
user_max_position = await app.db.fetchval(
"""
SELECT MAX(role.position)
FROM member_roles
JOIN roles ON roles.id = member_roles.role_id
WHERE member_roles.member_id = $1
""",
user_id,
)
assert target_max_position is not None
assert user_max_position is not None
if user_max_position <= target_max_position:
hasperm = False
if not hasperm and raise_err:
raise MissingPermissions("Missing permissions.") raise MissingPermissions("Missing permissions.")
return bool(hasperm) return bool(hasperm)
async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True): # TODO add target semantics
async def channel_perm_check(
user_id,
channel_id,
permission: str,
raise_err=True,
):
"""Check channel permissions for a user.""" """Check channel permissions for a user."""
base_perms = await get_permissions(user_id, channel_id) base_perms = await get_permissions(user_id, channel_id)
hasperm = getattr(base_perms.bits, permission) hasperm = getattr(base_perms.bits, permission)