From 78be4c6fab49832432d66996432c7c3d2e3f701d Mon Sep 17 00:00:00 2001 From: Luna Date: Thu, 6 Dec 2018 00:00:36 -0300 Subject: [PATCH] blueprints.invites: more complete errors --- litecord/blueprints/invites.py | 91 +++++++++++++++++++++++++--------- 1 file changed, 67 insertions(+), 24 deletions(-) diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py index d97c106..0940f07 100644 --- a/litecord/blueprints/invites.py +++ b/litecord/blueprints/invites.py @@ -9,7 +9,7 @@ from logbook import Logger from ..auth import token_check from ..schemas import validate, INVITE from ..enums import ChannelType -from ..errors import BadRequest +from ..errors import BadRequest, Forbidden from .guilds import create_guild_settings from ..utils import async_map @@ -20,31 +20,32 @@ from litecord.blueprints.checks import ( log = Logger(__name__) bp = Blueprint('invites', __name__) -# TODO: Ban handling -async def use_invite(user_id, invite_code): - """Try using an invite""" - inv = await app.db.fetchrow(""" - SELECT guild_id, created_at, max_age, uses, max_uses - FROM invites - WHERE code = $1 - """, invite_code) - - if inv is None: - raise BadRequest('Unknown invite') - - if inv['max_age'] is not 0: - now = datetime.datetime.utcnow() - delta_sec = (now - inv['created_at']).total_seconds() - if delta_sec > inv['max_age']: - await delete_invite(invite_code) - raise BadRequest('Unknown invite (expiried)') +class UnknownInvite(BadRequest): + error_code = 10006 - if inv['max_uses'] is not -1 and inv['uses'] > inv['max_uses']: - await delete_invite(invite_code) - raise BadRequest('Unknown invite (too many uses)') - guild_id = inv['guild_id'] +class InvalidInvite(Forbidden): + error_code = 50020 + + +def gen_inv_code() -> str: + """Generate an invite code. + + This is a primitive and does not guarantee uniqueness. + """ + # TODO: should we really be depending on os.urandom? + raw = os.urandom(7) + raw = base64.b64encode(raw).decode() + + raw = raw.replace('/', '') + raw = raw.replace('+', '') + + return raw[:7] + + +async def invite_precheck(user_id: int, guild_id: int): + """pre-check invite use in the context of a guild.""" joined = await app.db.fetchval(""" SELECT joined_at @@ -55,6 +56,42 @@ async def use_invite(user_id, invite_code): if joined is not None: raise BadRequest('You are already in the guild') + banned = await app.db.fetchval(""" + SELECT reason + FROM bans + WHERE user_id = $1 AND guild_id = $2 + """, user_id, guild_id) + + if banned is not None: + raise InvalidInvite('You are banned.') + + +async def use_invite(user_id, invite_code): + """Try using an invite""" + inv = await app.db.fetchrow(""" + SELECT guild_id, created_at, max_age, uses, max_uses + FROM invites + WHERE code = $1 + """, invite_code) + + if inv is None: + raise UnknownInvite('Unknown invite') + + if inv['max_age'] is not 0: + now = datetime.datetime.utcnow() + delta_sec = (now - inv['created_at']).total_seconds() + + if delta_sec > inv['max_age']: + await delete_invite(invite_code) + raise InvalidInvite('Invite is expired') + + if inv['max_uses'] is not -1 and inv['uses'] > inv['max_uses']: + await delete_invite(invite_code) + raise InvalidInvite('Too many uses') + + guild_id = inv['guild_id'] + await invite_precheck(user_id, guild_id) + await app.db.execute(""" INSERT INTO members (user_id, guild_id) VALUES ($1, $2) @@ -97,6 +134,7 @@ async def use_invite(user_id, invite_code): await app.dispatcher.dispatch_user_guild( user_id, guild_id, 'GUILD_CREATE', guild) + @bp.route('/channels//invites', methods=['POST']) async def create_invite(channel_id): user_id = await token_check() @@ -107,11 +145,16 @@ async def create_invite(channel_id): await channel_perm_check(user_id, channel_id, 'create_invites') chantype = await app.storage.get_chan_type(channel_id) + + # can't create invites for channels that aren't text + # or voice. + + # TODO: once group dms are in, this should change to account. if chantype not in (ChannelType.GUILD_TEXT.value, ChannelType.GUILD_VOICE.value): raise BadRequest('Invalid channel type') - invite_code = base64.b64encode(hashlib.md5(os.urandom(64)).digest()).decode("utf-8").replace("/", "").replace("+", "")[:7] + invite_code = gen_inv_code() await app.db.execute( """