mirror of https://gitlab.com/litecord/litecord.git
checks: add target_member_id to guild_perm_check
This commit is contained in:
parent
7374091dec
commit
6c1a73233b
|
|
@ -113,18 +113,66 @@ async def channel_check(
|
|||
return ctype, owner_id
|
||||
|
||||
|
||||
async def guild_perm_check(user_id, guild_id, permission: str):
|
||||
"""Check guild permissions for a user."""
|
||||
async def guild_perm_check(
|
||||
user_id,
|
||||
guild_id,
|
||||
permission: str,
|
||||
target_member_id: Optional[int] = None,
|
||||
raise_err: bool = True,
|
||||
) -> bool:
|
||||
"""Check guild permissions for a user.
|
||||
|
||||
Accepts optional target argument for actions that are done TO another member
|
||||
in the guild."""
|
||||
base_perms = await base_permissions(user_id, guild_id)
|
||||
hasperm = getattr(base_perms.bits, permission)
|
||||
|
||||
if not hasperm:
|
||||
# if we have the PERM and there's a target member involved,
|
||||
# check on the target's max(role.position), if its equal or greater,
|
||||
# raise MissingPermissions
|
||||
if hasperm and target_member_id:
|
||||
|
||||
# there is no internal function to fetch full role objects
|
||||
# (likely because it would be too expensive to do it here),
|
||||
# so instead do a raw sql query.
|
||||
target_max_position = await app.db.fetchval(
|
||||
"""
|
||||
SELECT MAX(role.position)
|
||||
FROM member_roles
|
||||
JOIN roles ON roles.id = member_roles.role_id
|
||||
WHERE member_roles.member_id = $1
|
||||
""",
|
||||
target_member_id,
|
||||
)
|
||||
user_max_position = await app.db.fetchval(
|
||||
"""
|
||||
SELECT MAX(role.position)
|
||||
FROM member_roles
|
||||
JOIN roles ON roles.id = member_roles.role_id
|
||||
WHERE member_roles.member_id = $1
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
assert target_max_position is not None
|
||||
assert user_max_position is not None
|
||||
|
||||
if user_max_position <= target_max_position:
|
||||
hasperm = False
|
||||
|
||||
if not hasperm and raise_err:
|
||||
raise MissingPermissions("Missing permissions.")
|
||||
|
||||
return bool(hasperm)
|
||||
|
||||
|
||||
async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True):
|
||||
# TODO add target semantics
|
||||
async def channel_perm_check(
|
||||
user_id,
|
||||
channel_id,
|
||||
permission: str,
|
||||
raise_err=True,
|
||||
):
|
||||
"""Check channel permissions for a user."""
|
||||
base_perms = await get_permissions(user_id, channel_id)
|
||||
hasperm = getattr(base_perms.bits, permission)
|
||||
|
|
|
|||
Loading…
Reference in New Issue