diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 52ab5bc..90d67e7 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -113,18 +113,88 @@ async def channel_check( return ctype, owner_id -async def guild_perm_check(user_id, guild_id, permission: str): - """Check guild permissions for a user.""" +async def _max_role_position(guild_id, member_id) -> Optional[int]: + return await app.db.fetchval( + """ + SELECT MAX(roles.position) + FROM member_roles + JOIN roles ON roles.id = member_roles.role_id + WHERE member_roles.guild_id = $1 AND + member_roles.user_id = $2 + """, + guild_id, + member_id, + ) + + +async def _validate_target_member( + guild_id: int, user_id: int, target_member_id: int +) -> bool: + owner_id = await app.storage.db.fetchval( + """ + SELECT owner_id + FROM guilds + WHERE id = $1 + """, + guild_id, + ) + assert owner_id is not None + + # owners have all permissions + # if doing an action as an owner, it always works + # if doing an action TO an owner, it always fails + if user_id == owner_id: + return True + + if target_member_id == owner_id: + return False + + # there is no internal function to fetch full role objects + # (likely because it would be too expensive to do it here), + # so instead do a raw sql query. + + target_max_position = await _max_role_position(guild_id, target_member_id) + user_max_position = await _max_role_position(guild_id, user_id) + + assert target_max_position is not None + assert user_max_position is not None + + return user_max_position > target_max_position + + +async def guild_perm_check( + user_id, + guild_id, + permission: str, + target_member_id: Optional[int] = None, + raise_err: bool = True, +) -> bool: + """Check guild permissions for a user. + + Accepts optional target argument for actions that are done TO another member + in the guild.""" base_perms = await base_permissions(user_id, guild_id) hasperm = getattr(base_perms.bits, permission) - if not hasperm: + # if we have the PERM and there's a target member involved, + # check on the target's max(role.position), if its equal or greater, + # raise MissingPermissions + if hasperm and target_member_id: + hasperm = await _validate_target_member(guild_id, user_id, target_member_id) + + if not hasperm and raise_err: raise MissingPermissions("Missing permissions.") return bool(hasperm) -async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True): +# TODO add target semantics +async def channel_perm_check( + user_id, + channel_id, + permission: str, + raise_err=True, +): """Check channel permissions for a user.""" base_perms = await get_permissions(user_id, channel_id) hasperm = getattr(base_perms.bits, permission) diff --git a/litecord/blueprints/guild/members.py b/litecord/blueprints/guild/members.py index e6ec42b..0f77dc8 100644 --- a/litecord/blueprints/guild/members.py +++ b/litecord/blueprints/guild/members.py @@ -129,7 +129,7 @@ async def modify_guild_member(guild_id, member_id): nick_flag = False if "nick" in j: - await guild_perm_check(user_id, guild_id, "manage_nicknames") + await guild_perm_check(user_id, guild_id, "manage_nicknames", member_id) nick = j["nick"] or None diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py index f108956..8084e72 100644 --- a/litecord/blueprints/guild/mod.py +++ b/litecord/blueprints/guild/mod.py @@ -23,7 +23,7 @@ from litecord.blueprints.auth import token_check from litecord.blueprints.checks import guild_perm_check from litecord.schemas import validate, GUILD_PRUNE -from litecord.common.guilds import remove_member, remove_member_multi +from litecord.common.guilds import remove_member bp = Blueprint("guild_moderation", __name__) @@ -32,8 +32,7 @@ bp = Blueprint("guild_moderation", __name__) async def kick_guild_member(guild_id, member_id): """Remove a member from a guild.""" user_id = await token_check() - - await guild_perm_check(user_id, guild_id, "kick_members") + await guild_perm_check(user_id, guild_id, "kick_members", member_id) await remove_member(guild_id, member_id) return "", 204 @@ -70,7 +69,7 @@ async def get_bans(guild_id): async def create_ban(guild_id, member_id): user_id = await token_check() - await guild_perm_check(user_id, guild_id, "ban_members") + await guild_perm_check(user_id, guild_id, "ban_members", member_id) j = await request.get_json() @@ -179,6 +178,19 @@ async def get_guild_prune_count(guild_id): return jsonify({"pruned": len(member_ids)}) +async def prune_members(user_id, guild_id, member_ids): + # calculate permissions against each pruned member, don't prune + # if permissions don't allow it + for member_id in member_ids: + has_permissions = await guild_perm_check( + user_id, guild_id, "kick_members", member_id, raise_err=False + ) + if not has_permissions: + continue + + await remove_member(guild_id, member_id) + + @bp.route("//prune", methods=["POST"]) async def begin_guild_prune(guild_id): user_id = await token_check() @@ -189,5 +201,5 @@ async def begin_guild_prune(guild_id): days = j["days"] member_ids = await get_prune(guild_id, days) - app.sched.spawn(remove_member_multi(guild_id, member_ids)) + app.sched.spawn(prune_members(user_id, guild_id, member_ids)) return jsonify({"pruned": len(member_ids)}) diff --git a/litecord/common/guilds.py b/litecord/common/guilds.py index 9212c2b..c458f6e 100644 --- a/litecord/common/guilds.py +++ b/litecord/common/guilds.py @@ -50,7 +50,11 @@ async def remove_member(guild_id: int, member_id: int): user = await app.storage.get_user(member_id) - await app.dispatcher.guild.unsub(guild_id, member_id) + states, channels = await app.dispatcher.guild.unsub_user(guild_id, member_id) + for channel_id in channels: + for state in states: + await app.dispatcher.channel.unsub(channel_id, state.session_id) + await app.lazy_guild.remove_member(guild_id, user["id"]) await app.dispatcher.guild.dispatch( guild_id, @@ -61,12 +65,6 @@ async def remove_member(guild_id: int, member_id: int): ) -async def remove_member_multi(guild_id: int, members: list): - """Remove multiple members.""" - for member_id in members: - await remove_member(guild_id, member_id) - - async def create_role(guild_id, name: str, **kwargs): """Create a role in a guild.""" new_role_id = app.winter_factory.snowflake() diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index 73e363e..32a3e6d 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -71,6 +71,16 @@ class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]): return states, channel_ids + async def unsub_user( + self, guild_id: int, user_id: int + ) -> Tuple[List[GatewayState], List[int]]: + states = app.state_manager.fetch_states(user_id, guild_id) + for state in states: + await self.unsub(guild_id, state.session_id) + + guild_chan_ids = await app.storage.get_channel_ids(guild_id) + return states, guild_chan_ids + async def dispatch_filter( self, guild_id: int, filter_function, event: GatewayEvent ):