diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py index cb2898c..ed7c048 100644 --- a/litecord/blueprints/invites.py +++ b/litecord/blueprints/invites.py @@ -8,10 +8,13 @@ from ..auth import token_check from ..schemas import validate, INVITE from ..enums import ChannelType from ..errors import BadRequest, Forbidden -from .channels import channel_check -from .guilds import guild_check, create_guild_settings +from .guilds import create_guild_settings from ..utils import async_map +from litecord.blueprints.checks import ( + channel_check, channel_perm_check, guild_check, guild_perm_check +) + log = Logger(__name__) bp = Blueprint('invites', __name__) @@ -23,7 +26,7 @@ async def create_invite(channel_id): j = validate(await request.get_json(), INVITE) _ctype, guild_id = await channel_check(user_id, channel_id) - # TODO: check CREATE_INSTANT_INVITE permission + await channel_perm_check(user_id, channel_id, 'create_invites') chantype = await app.storage.get_chan_type(channel_id) if chantype not in (ChannelType.GUILD_TEXT.value, @@ -70,26 +73,16 @@ async def get_invite_2(invite_code: str): async def delete_invite(invite_code: str): user_id = await token_check() - gid = await app.db.fetchval(""" + guild_id = await app.db.fetchval(""" SELECT guild_id FROM invites WHERE code = $1 """, invite_code) - if gid is None: + if guild_id is None: raise BadRequest('Unknown invite') - # TODO: check MANAGE_CHANNELS permission - # for now we'll go with checking owner - - owner_id = await app.db.fetchval(""" - SELECT owner_id - FROM guilds - WHERE id = $1 - """. gid) - - if owner_id != user_id: - raise Forbidden('Not guild owner') + await guild_perm_check(user_id, guild_id, 'manage_channels') inv = await app.storage.get_invite(invite_code) @@ -109,8 +102,11 @@ async def _get_inv(code): @bp.route('/guilds//invites', methods=['GET']) async def get_guild_invites(guild_id: int): + """Get all invites for a guild.""" user_id = await token_check() + await guild_check(user_id, guild_id) + await guild_perm_check(user_id, guild_id, 'manage_guild') inv_codes = await app.db.fetch(""" SELECT code @@ -118,8 +114,6 @@ async def get_guild_invites(guild_id: int): WHERE guild_id = $1 """, guild_id) - # TODO: MANAGE_GUILD permission - inv_codes = [r['code'] for r in inv_codes] invs = await async_map(_get_inv, inv_codes) return jsonify(invs) @@ -127,8 +121,11 @@ async def get_guild_invites(guild_id: int): @bp.route('/channels//invites', methods=['GET']) async def get_channel_invites(channel_id: int): + """Get all invites for a channel.""" user_id = await token_check() + _ctype, guild_id = await channel_check(user_id, channel_id) + await guild_perm_check(user_id, guild_id, 'manage_channels') inv_codes = await app.db.fetch(""" SELECT code @@ -136,8 +133,6 @@ async def get_channel_invites(channel_id: int): WHERE guild_id = $1 AND channel_id = $2 """, guild_id, channel_id) - # TODO: check MANAGE_CHANNELS permission - inv_codes = [r['code'] for r in inv_codes] invs = await async_map(_get_inv, inv_codes) return jsonify(invs) @@ -145,6 +140,7 @@ async def get_channel_invites(channel_id: int): @bp.route('/invite/', methods=['POST']) async def use_invite(invite_code): + """Use an invite.""" user_id = await token_check() guild_id = await app.db.fetchval("""