From 6c1a73233b95ce743921f96f2b8124543054d0e2 Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 30 Aug 2021 22:13:20 -0300 Subject: [PATCH 01/10] checks: add target_member_id to guild_perm_check --- litecord/blueprints/checks.py | 56 ++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 52ab5bc..689f86b 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -113,18 +113,66 @@ async def channel_check( return ctype, owner_id -async def guild_perm_check(user_id, guild_id, permission: str): - """Check guild permissions for a user.""" +async def guild_perm_check( + user_id, + guild_id, + permission: str, + target_member_id: Optional[int] = None, + raise_err: bool = True, +) -> bool: + """Check guild permissions for a user. + + Accepts optional target argument for actions that are done TO another member + in the guild.""" base_perms = await base_permissions(user_id, guild_id) hasperm = getattr(base_perms.bits, permission) - if not hasperm: + # if we have the PERM and there's a target member involved, + # check on the target's max(role.position), if its equal or greater, + # raise MissingPermissions + if hasperm and target_member_id: + + # there is no internal function to fetch full role objects + # (likely because it would be too expensive to do it here), + # so instead do a raw sql query. + target_max_position = await app.db.fetchval( + """ + SELECT MAX(role.position) + FROM member_roles + JOIN roles ON roles.id = member_roles.role_id + WHERE member_roles.member_id = $1 + """, + target_member_id, + ) + user_max_position = await app.db.fetchval( + """ + SELECT MAX(role.position) + FROM member_roles + JOIN roles ON roles.id = member_roles.role_id + WHERE member_roles.member_id = $1 + """, + user_id, + ) + + assert target_max_position is not None + assert user_max_position is not None + + if user_max_position <= target_max_position: + hasperm = False + + if not hasperm and raise_err: raise MissingPermissions("Missing permissions.") return bool(hasperm) -async def channel_perm_check(user_id, channel_id, permission: str, raise_err=True): +# TODO add target semantics +async def channel_perm_check( + user_id, + channel_id, + permission: str, + raise_err=True, +): """Check channel permissions for a user.""" base_perms = await get_permissions(user_id, channel_id) hasperm = getattr(base_perms.bits, permission) From e2b150ee21d717613b17a86fbc7b655510ad4fb3 Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 30 Aug 2021 22:43:11 -0300 Subject: [PATCH 02/10] cheks: refactor target member --- litecord/blueprints/checks.py | 56 +++++++++++++++++------------------ 1 file changed, 28 insertions(+), 28 deletions(-) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 689f86b..3632e16 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -113,6 +113,33 @@ async def channel_check( return ctype, owner_id +async def _max_role_position(guild_id, member_id) -> Optional[int]: + return await app.db.fetchval( + """ + SELECT MAX(role.position) + FROM member_roles + JOIN roles ON roles.id = member_roles.role_id + WHERE member_roles.guild_id = $1 AND + member_roles.member_id = $2 + """, + guild_id, + member_id, + ) + + +async def _validate_target_member(guild_id: int, user_id: int, target_member_id: int): + # there is no internal function to fetch full role objects + # (likely because it would be too expensive to do it here), + # so instead do a raw sql query. + target_max_position = await _max_role_position(guild_id, target_member_id) + user_max_position = await _max_role_position(guild_id, user_id) + + assert target_max_position is not None + assert user_max_position is not None + + return user_max_position > target_max_position + + async def guild_perm_check( user_id, guild_id, @@ -131,34 +158,7 @@ async def guild_perm_check( # check on the target's max(role.position), if its equal or greater, # raise MissingPermissions if hasperm and target_member_id: - - # there is no internal function to fetch full role objects - # (likely because it would be too expensive to do it here), - # so instead do a raw sql query. - target_max_position = await app.db.fetchval( - """ - SELECT MAX(role.position) - FROM member_roles - JOIN roles ON roles.id = member_roles.role_id - WHERE member_roles.member_id = $1 - """, - target_member_id, - ) - user_max_position = await app.db.fetchval( - """ - SELECT MAX(role.position) - FROM member_roles - JOIN roles ON roles.id = member_roles.role_id - WHERE member_roles.member_id = $1 - """, - user_id, - ) - - assert target_max_position is not None - assert user_max_position is not None - - if user_max_position <= target_max_position: - hasperm = False + hasperm = await _validate_target_member(guild_id, user_id, target_member_id) if not hasperm and raise_err: raise MissingPermissions("Missing permissions.") From cb23d972c60b64fb7ad8aede86006f0489edda53 Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 30 Aug 2021 22:51:53 -0300 Subject: [PATCH 03/10] add targets to guild_perm_check calls --- litecord/blueprints/guild/members.py | 2 +- litecord/blueprints/guild/mod.py | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/litecord/blueprints/guild/members.py b/litecord/blueprints/guild/members.py index e6ec42b..0f77dc8 100644 --- a/litecord/blueprints/guild/members.py +++ b/litecord/blueprints/guild/members.py @@ -129,7 +129,7 @@ async def modify_guild_member(guild_id, member_id): nick_flag = False if "nick" in j: - await guild_perm_check(user_id, guild_id, "manage_nicknames") + await guild_perm_check(user_id, guild_id, "manage_nicknames", member_id) nick = j["nick"] or None diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py index f108956..8084e72 100644 --- a/litecord/blueprints/guild/mod.py +++ b/litecord/blueprints/guild/mod.py @@ -23,7 +23,7 @@ from litecord.blueprints.auth import token_check from litecord.blueprints.checks import guild_perm_check from litecord.schemas import validate, GUILD_PRUNE -from litecord.common.guilds import remove_member, remove_member_multi +from litecord.common.guilds import remove_member bp = Blueprint("guild_moderation", __name__) @@ -32,8 +32,7 @@ bp = Blueprint("guild_moderation", __name__) async def kick_guild_member(guild_id, member_id): """Remove a member from a guild.""" user_id = await token_check() - - await guild_perm_check(user_id, guild_id, "kick_members") + await guild_perm_check(user_id, guild_id, "kick_members", member_id) await remove_member(guild_id, member_id) return "", 204 @@ -70,7 +69,7 @@ async def get_bans(guild_id): async def create_ban(guild_id, member_id): user_id = await token_check() - await guild_perm_check(user_id, guild_id, "ban_members") + await guild_perm_check(user_id, guild_id, "ban_members", member_id) j = await request.get_json() @@ -179,6 +178,19 @@ async def get_guild_prune_count(guild_id): return jsonify({"pruned": len(member_ids)}) +async def prune_members(user_id, guild_id, member_ids): + # calculate permissions against each pruned member, don't prune + # if permissions don't allow it + for member_id in member_ids: + has_permissions = await guild_perm_check( + user_id, guild_id, "kick_members", member_id, raise_err=False + ) + if not has_permissions: + continue + + await remove_member(guild_id, member_id) + + @bp.route("//prune", methods=["POST"]) async def begin_guild_prune(guild_id): user_id = await token_check() @@ -189,5 +201,5 @@ async def begin_guild_prune(guild_id): days = j["days"] member_ids = await get_prune(guild_id, days) - app.sched.spawn(remove_member_multi(guild_id, member_ids)) + app.sched.spawn(prune_members(user_id, guild_id, member_ids)) return jsonify({"pruned": len(member_ids)}) From 4c2bbe89a1c04cc792d35fc0eb582f914ea320e4 Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 30 Aug 2021 22:59:33 -0300 Subject: [PATCH 04/10] remove unused function --- litecord/common/guilds.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/litecord/common/guilds.py b/litecord/common/guilds.py index 9212c2b..3c8484f 100644 --- a/litecord/common/guilds.py +++ b/litecord/common/guilds.py @@ -61,12 +61,6 @@ async def remove_member(guild_id: int, member_id: int): ) -async def remove_member_multi(guild_id: int, members: list): - """Remove multiple members.""" - for member_id in members: - await remove_member(guild_id, member_id) - - async def create_role(guild_id, name: str, **kwargs): """Create a role in a guild.""" new_role_id = app.winter_factory.snowflake() From 0c14473a9588d2b93a7ff870684692974caa0103 Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 30 Aug 2021 23:18:28 -0300 Subject: [PATCH 05/10] checks: validate when target or user is an owner --- litecord/blueprints/checks.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 3632e16..f6b43b2 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -127,10 +127,31 @@ async def _max_role_position(guild_id, member_id) -> Optional[int]: ) -async def _validate_target_member(guild_id: int, user_id: int, target_member_id: int): +async def _validate_target_member( + guild_id: int, user_id: int, target_member_id: int +) -> bool: + owner_id = await storage.db.fetchval( + """ + SELECT owner_id + FROM guilds + WHERE id = $1 + """, + guild_id, + ) + + # owners have all permissions + # if doing an action as an owner, it always works + # if doing an action TO an owner, it always fails + if user_id == owner_id: + return True + + if target_member_id == owner_id: + return False + # there is no internal function to fetch full role objects # (likely because it would be too expensive to do it here), # so instead do a raw sql query. + target_max_position = await _max_role_position(guild_id, target_member_id) user_max_position = await _max_role_position(guild_id, user_id) From 36a69ad8cba6203442806aac5ed3179e1ae9c116 Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 30 Aug 2021 23:22:54 -0300 Subject: [PATCH 06/10] fix typo --- litecord/blueprints/checks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index f6b43b2..3571a0e 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -130,7 +130,7 @@ async def _max_role_position(guild_id, member_id) -> Optional[int]: async def _validate_target_member( guild_id: int, user_id: int, target_member_id: int ) -> bool: - owner_id = await storage.db.fetchval( + owner_id = await app.storage.db.fetchval( """ SELECT owner_id FROM guilds @@ -138,6 +138,7 @@ async def _validate_target_member( """, guild_id, ) + assert owner_id is not None # owners have all permissions # if doing an action as an owner, it always works From 70cd40966d9526c3db2872f559b4550790a8320e Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 30 Aug 2021 23:37:02 -0300 Subject: [PATCH 07/10] fix typo on sql query --- litecord/blueprints/checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 3571a0e..555ad77 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -116,7 +116,7 @@ async def channel_check( async def _max_role_position(guild_id, member_id) -> Optional[int]: return await app.db.fetchval( """ - SELECT MAX(role.position) + SELECT MAX(roles.position) FROM member_roles JOIN roles ON roles.id = member_roles.role_id WHERE member_roles.guild_id = $1 AND From 43e34cde3aa1441b4a6f2f3b7235a1b34e6019b3 Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 30 Aug 2021 23:39:53 -0300 Subject: [PATCH 08/10] fix typo in sql statement --- litecord/blueprints/checks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 555ad77..90d67e7 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -120,7 +120,7 @@ async def _max_role_position(guild_id, member_id) -> Optional[int]: FROM member_roles JOIN roles ON roles.id = member_roles.role_id WHERE member_roles.guild_id = $1 AND - member_roles.member_id = $2 + member_roles.user_id = $2 """, guild_id, member_id, From 18190751373821decba70f5408acf0316595342f Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 30 Aug 2021 23:52:35 -0300 Subject: [PATCH 09/10] properly unsubscribe member when being removed --- litecord/common/guilds.py | 6 +++++- litecord/pubsub/guild.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/litecord/common/guilds.py b/litecord/common/guilds.py index 3c8484f..479eb84 100644 --- a/litecord/common/guilds.py +++ b/litecord/common/guilds.py @@ -50,7 +50,11 @@ async def remove_member(guild_id: int, member_id: int): user = await app.storage.get_user(member_id) - await app.dispatcher.guild.unsub(guild_id, member_id) + states, channels = await app.dispatcher.guild.unsub_user(guild_id, user_id) + for channel_id in channels: + for state in states: + await app.dispatcher.channel.unsub(channel_id, state.session_id) + await app.lazy_guild.remove_member(guild_id, user["id"]) await app.dispatcher.guild.dispatch( guild_id, diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index 73e363e..32a3e6d 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -71,6 +71,16 @@ class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]): return states, channel_ids + async def unsub_user( + self, guild_id: int, user_id: int + ) -> Tuple[List[GatewayState], List[int]]: + states = app.state_manager.fetch_states(user_id, guild_id) + for state in states: + await self.unsub(guild_id, state.session_id) + + guild_chan_ids = await app.storage.get_channel_ids(guild_id) + return states, guild_chan_ids + async def dispatch_filter( self, guild_id: int, filter_function, event: GatewayEvent ): From dd10d7f99667e2087a6e48712f60cd5a9ae8b59e Mon Sep 17 00:00:00 2001 From: Luna Date: Mon, 30 Aug 2021 23:54:13 -0300 Subject: [PATCH 10/10] s/user_id/member_id --- litecord/common/guilds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litecord/common/guilds.py b/litecord/common/guilds.py index 479eb84..c458f6e 100644 --- a/litecord/common/guilds.py +++ b/litecord/common/guilds.py @@ -50,7 +50,7 @@ async def remove_member(guild_id: int, member_id: int): user = await app.storage.get_user(member_id) - states, channels = await app.dispatcher.guild.unsub_user(guild_id, user_id) + states, channels = await app.dispatcher.guild.unsub_user(guild_id, member_id) for channel_id in channels: for state in states: await app.dispatcher.channel.unsub(channel_id, state.session_id)