mirror of https://gitlab.com/litecord/litecord.git
checks: add target_member_id to guild_perm_check
This commit is contained in:
parent
7374091dec
commit
6c1a73233b
|
|
@ -113,18 +113,66 @@ async def channel_check(
|
||||||
return ctype, owner_id
|
return ctype, owner_id
|
||||||
|
|
||||||
|
|
||||||
async def guild_perm_check(user_id, guild_id, permission: str):
|
async def guild_perm_check(
|
||||||
"""Check guild permissions for a user."""
|
user_id,
|
||||||
|
guild_id,
|
||||||
|
permission: str,
|
||||||
|
target_member_id: Optional[int] = None,
|
||||||
|
raise_err: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""Check guild permissions for a user.
|
||||||
|
|
||||||
|
Accepts optional target argument for actions that are done TO another member
|
||||||
|
in the guild."""
|
||||||
base_perms = await base_permissions(user_id, guild_id)
|
base_perms = await base_permissions(user_id, guild_id)
|
||||||
hasperm = getattr(base_perms.bits, permission)
|
hasperm = getattr(base_perms.bits, permission)
|
||||||
|
|
||||||
if not hasperm:
|
# if we have the PERM and there's a target member involved,
|
||||||
|
# check on the target's max(role.position), if its equal or greater,
|
||||||
|
# raise MissingPermissions
|
||||||
|
if hasperm and target_member_id:
|
||||||
|
|
||||||
|
# there is no internal function to fetch full role objects
|
||||||
|
# (likely because it would be too expensive to do it here),
|
||||||
|
# so instead do a raw sql query.
|
||||||
|
target_max_position = await app.db.fetchval(
|
||||||
|
"""
|
||||||
|
SELECT MAX(role.position)
|
||||||
|
FROM member_roles
|
||||||
|
JOIN roles ON roles.id = member_roles.role_id
|
||||||
|
WHERE member_roles.member_id = $1
|
||||||
|
""",
|
||||||
|
target_member_id,
|
||||||
|
)
|
||||||
|
user_max_position = await app.db.fetchval(
|
||||||
|
"""
|
||||||
|
SELECT MAX(role.position)
|
||||||
|
FROM member_roles
|
||||||
|
JOIN roles ON roles.id = member_roles.role_id
|
||||||
|
WHERE member_roles.member_id = $1
|
||||||
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert target_max_position is not None
|
||||||
|
assert user_max_position is not None
|
||||||
|
|
||||||
|
if user_max_position <= target_max_position:
|
||||||
|
hasperm = False
|
||||||
|
|
||||||
|
if not hasperm and raise_err:
|
||||||
raise MissingPermissions("Missing permissions.")
|
raise MissingPermissions("Missing permissions.")
|
||||||
|
|
||||||
return bool(hasperm)
|
return bool(hasperm)
|
||||||
|
|
||||||
|
|
||||||
async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True):
|
# TODO add target semantics
|
||||||
|
async def channel_perm_check(
|
||||||
|
user_id,
|
||||||
|
channel_id,
|
||||||
|
permission: str,
|
||||||
|
raise_err=True,
|
||||||
|
):
|
||||||
"""Check channel permissions for a user."""
|
"""Check channel permissions for a user."""
|
||||||
base_perms = await get_permissions(user_id, channel_id)
|
base_perms = await get_permissions(user_id, channel_id)
|
||||||
hasperm = getattr(base_perms.bits, permission)
|
hasperm = getattr(base_perms.bits, permission)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue