diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py index 12e5cb7..06eee6f 100644 --- a/litecord/blueprints/invites.py +++ b/litecord/blueprints/invites.py @@ -1,3 +1,4 @@ +import datetime import hashlib import os @@ -69,8 +70,16 @@ async def get_invite_2(invite_code: str): return await get_invite(invite_code) -@bp.route('/invites/', methods=['DELETE']) async def delete_invite(invite_code: str): + """Delete an invite.""" + await app.db.fetchval(""" + DELETE FROM invites + WHERE code = $1 + """, invite_code) + + +@bp.route('/invites/', methods=['DELETE']) +async def _delete_invite(invite_code: str): user_id = await token_check() guild_id = await app.db.fetchval(""" @@ -85,18 +94,13 @@ async def delete_invite(invite_code: str): await guild_perm_check(user_id, guild_id, 'manage_channels') inv = await app.storage.get_invite(invite_code) - - await app.db.fetchval(""" - DELETE FROM invites - WHERE code = $1 - """, invite_code) - + await delete_invite(invite_code) return jsonify(inv) @bp.route('/invite/', methods=['DELETE']) -async def delete_invite_2(invite_code: str): - return await delete_invite(invite_code) +async def _delete_invite_2(invite_code: str): + return await _delete_invite(invite_code) async def _get_inv(code): @@ -148,14 +152,27 @@ async def use_invite(invite_code): """Use an invite.""" user_id = await token_check() - guild_id = await app.db.fetchval(""" - SELECT guild_id + inv = await app.db.fetchrow(""" + SELECT guild_id, created_at, max_age, uses, max_uses FROM invites WHERE code = $1 """, invite_code) - if not guild_id: - raise BadRequest('Guild not Found') + if inv is None: + raise BadRequest('Invite not found') + + now = datetime.datetime.utcnow() + delta_sec = (now - inv['created_at']).total_seconds() + + if delta_sec > inv['max_age']: + await delete_invite(invite_code) + raise BadRequest('Invite has expired (age).') + + if inv['uses'] > inv['max_uses']: + await delete_invite(invite_code) + raise BadRequest('Invite has expired (uses).') + + guild_id = inv['guild_id'] joined = await app.db.fetchval(""" SELECT joined_at