diff --git a/litecord/blueprints/__init__.py b/litecord/blueprints/__init__.py index 433b164..557f326 100644 --- a/litecord/blueprints/__init__.py +++ b/litecord/blueprints/__init__.py @@ -6,3 +6,4 @@ from .channels import bp as channels from .webhooks import bp as webhooks from .science import bp as science from .voice import bp as voice +from .invites import bp as invites diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py new file mode 100644 index 0000000..ceacd65 --- /dev/null +++ b/litecord/blueprints/invites.py @@ -0,0 +1,56 @@ +import hashlib +import os + +from quart import Blueprint, request, current_app as app, jsonify +from logbook import Logger + +from ..auth import token_check +from ..schemas import validate, INVITE +from ..enums import ChannelType +from ..errors import BadRequest +from .channels import channel_check + +log = Logger(__name__) +bp = Blueprint('invites', __name__) + + +@bp.route('/channels//invites', methods=['POST']) +async def create_invite(channel_id): + user_id = await token_check() + + j = validate(await request.get_json(), INVITE) + guild_id = await channel_check(user_id, channel_id) + + # TODO: check CREATE_INSTANT_INVITE permission + + chantype = await app.storage.get_chan_type(channel_id) + if chantype not in (ChannelType.GUILD_TEXT.value, + ChannelType.GUILD_VOICE.value): + raise BadRequest('Invalid channel type') + + invite_code = hashlib.md5(os.urandom(64)).hexdigest()[:16] + + await app.db.execute( + """ + INSERT INTO invites + (code, guild_id, channel_id, inviter, max_uses, + max_age, temporary) + VALUES ($1, $2, $3, $4, $5, $6, $7) + """, + invite_code, guild_id, channel_id, user_id, + j['max_uses'], j['max_age'], j['temporary'] + ) + + invite = await app.storage.get_invite(invite_code) + return jsonify(invite) + + +@bp.route('/invites/', methods=['GET']) +async def get_invite(invite_code: str): + inv = await app.storage.get_invite(invite_code) + + if request.args.get('with_counts'): + extra = await app.storage.get_invite_extra(invite_code) + inv.update(extra) + + return jsonify(inv) diff --git a/litecord/schemas.py b/litecord/schemas.py index 1dadf0a..607c6ea 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -1,10 +1,14 @@ import re from cerberus import Validator +from logbook import Logger from .errors import BadRequest from .enums import ActivityType, StatusType + +log = Logger(__name__) + USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A) EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$', re.A) @@ -48,6 +52,8 @@ class LitecordValidator(Validator): def validate(reqjson, schema, raise_err: bool = True): validator = LitecordValidator(schema) + log.debug('Validating {}', reqjson) + if not validator.validate(reqjson): errs = validator.errors @@ -174,3 +180,31 @@ GW_STATUS_UPDATE = { 'schema': GW_ACTIVITY, }, } + +INVITE = { + # max_age in seconds + # 0 for infinite + 'max_age': { + 'type': 'number', + 'min': 0, + 'max': 86400, + + # a day + 'default': 86400 + }, + + # max invite uses + 'max_uses': { + 'type': 'number', + 'min': 0, + + # idk + 'max': 1000, + + # default infinite + 'default': 0 + }, + + 'temporary': {'type': 'boolean', 'required': False, 'default': False}, + 'unique': {'type': 'boolean', 'required': False, 'default': True}, +} diff --git a/litecord/storage.py b/litecord/storage.py index b5a15a4..00698ff 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -13,6 +13,10 @@ async def _dummy(any_id): return str(any_id) +def dict_(val): + return dict(val) if val else None + + class Storage: """Class for common SQL statements.""" def __init__(self, db): @@ -73,10 +77,9 @@ class Storage: drow['system_channel_id'] = str(drow['system_channel_id']) \ if drow['system_channel_id'] else None - return {**drow, **{ - # TODO: those - 'emojis': [], - }} + # TODO: emojis + drow['emojis'] = [] + return drow async def get_user_guilds(self, user_id: int) -> List[int]: """Get all guild IDs a user is on.""" @@ -320,6 +323,15 @@ class Storage: ), }} + async def get_member_ids(self, guild_id: int) -> List[int]: + rows = await self.db.fetch(""" + SELECT user_id + FROM members + WHERE guild_id = $1 + """, guild_id) + + return [r[0] for r in rows] + async def _msg_regex(self, regex, method, content) -> List[Dict]: res = [] @@ -393,3 +405,76 @@ class Storage: return {str(row['target_id']): row['note'] for row in note_rows} + + async def get_invite(self, invite_code: str) -> dict: + """Fetch invite information given its code.""" + invite = await self.db.fetchrow(""" + SELECT code, guild_id, channel_id + FROM invites + WHERE code = $1 + """, invite_code) + + if invite is None: + return None + + dinv = dict_(invite) + + # fetch some guild info + guild = await self.db.fetchrow(""" + SELECT id::text, name, splash, icon + FROM guilds + WHERE id = $1 + """, invite['guild_id']) + + dinv['guild'] = dict(guild) + + chan = await self.get_channel(invite['channel_id']) + dinv['channel'] = { + 'id': chan['id'], + 'name': chan['name'], + 'type': chan['type'], + } + + dinv.pop('guild_id') + dinv.pop('channel_id') + + return dinv + + async def get_invite_extra(self, invite_code: str) -> dict: + """Extra information about the invite, such as + approximate guild and presence counts.""" + guild_id = await self.db.fetchval(""" + SELECT guild_id + FROM invites + WHERE code = $1 + """, invite_code) + + if guild_id is None: + return {} + + mids = await self.get_member_ids(guild_id) + pres = await self.presence.guild_presences(mids, guild_id) + online_count = sum(1 for p in pres if p['status'] == 'online') + + return { + 'approximate_presence_count': online_count, + 'approximate_member_count': len(mids), + } + + async def get_invite_metadata(self, invite_code: str) -> Dict[str, Any]: + """Fetch invite metadata (max_age and friends).""" + invite = await self.db.fetchrow(""" + SELECT code, inviter, created_at, uses, + max_uses, max_age, temporary, created_at, revoked + FROM invites + WHERE code = $1 + """, invite_code) + + if invite is None: + return + + dinv = dict_(invite) + inviter = await self.get_user(invite['inviter']) + dinv['inviter'] = inviter + + return dinv diff --git a/run.py b/run.py index 9b5da85..755a752 100644 --- a/run.py +++ b/run.py @@ -11,7 +11,7 @@ from logbook.compat import redirect_logging import config from litecord.blueprints import gateway, auth, users, guilds, channels, \ - webhooks, science, voice + webhooks, science, voice, invites from litecord.gateway import websocket_handler from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager @@ -53,7 +53,8 @@ bps = { channels: '/channels', webhooks: None, science: None, - voice: '/voice' + voice: '/voice', + invites: None } for bp, suffix in bps.items():