diff --git a/README.md b/README.md index 4658055..3b3756f 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,16 @@ This project is a rewrite of [litecord-reference]. [litecord-reference]: https://gitlab.com/luna/litecord-reference +## Notes + + - There are no testing being run on the current codebase. Which means the code is definitely unstable. + - No voice is planned to be developed, for now. + - You must figure out connecting to the server yourself. Litecord will not distribute + Discord's official client code nor provide ways to modify the client. + ## Install +Requirements: - Python 3.6 or higher - PostgreSQL - [Pipenv] @@ -28,6 +36,10 @@ $ psql -f schema.sql litecord # edit config.py as you wish $ cp config.example.py config.py +# run database migrations (this is a +# required step in setup) +$ pipenv run ./manage.py migrate + # Install all packages: $ pipenv install --dev ``` @@ -42,3 +54,10 @@ Use `--access-log -` to output access logs to stdout. ```sh $ pipenv run hypercorn run:app ``` + +## Updating + +```sh +$ git pull +$ pipenv run ./manage.py migrate +``` diff --git a/litecord/auth.py b/litecord/auth.py index fa8404b..498aa59 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -13,7 +13,11 @@ log = Logger(__name__) async def raw_token_check(token, db=None): db = db or app.db - user_id, _hmac = token.split('.') + + # just try by fragments instead of + # unpacking + fragments = token.split('.') + user_id = fragments[0] try: user_id = base64.b64decode(user_id.encode()) @@ -35,6 +39,17 @@ async def raw_token_check(token, db=None): try: signer.unsign(token) log.debug('login for uid {} successful', user_id) + + # update the user's last_session field + # so that we can keep an exact track of activity, + # even on long-lived single sessions (that can happen + # with people leaving their clients open forever) + await db.execute(""" + UPDATE users + SET last_session = (now() at time zone 'utc') + WHERE id = $1 + """, user_id) + return user_id except BadSignature: log.warning('token failed for uid {}', user_id) @@ -43,6 +58,12 @@ async def raw_token_check(token, db=None): async def token_check(): """Check token information.""" + # first, check if the request info already has a uid + try: + return request.user_id + except AttributeError: + pass + try: token = request.headers['Authorization'] except KeyError: @@ -51,4 +72,6 @@ async def token_check(): if token.startswith('Bot '): token = token.replace('Bot ', '') - return await raw_token_check(token) + user_id = await raw_token_check(token) + request.user_id = user_id + return user_id diff --git a/litecord/blueprints/auth.py b/litecord/blueprints/auth.py index 17e77f3..ce5544e 100644 --- a/litecord/blueprints/auth.py +++ b/litecord/blueprints/auth.py @@ -65,7 +65,7 @@ async def register(): new_id = get_snowflake() - new_discrim = str(random.randint(1, 9999)) + new_discrim = random.randint(1, 9999) new_discrim = '%04d' % new_discrim pwd_hash = await hash_data(password) diff --git a/litecord/blueprints/channel/__init__.py b/litecord/blueprints/channel/__init__.py new file mode 100644 index 0000000..4337684 --- /dev/null +++ b/litecord/blueprints/channel/__init__.py @@ -0,0 +1,3 @@ +from .messages import bp as channel_messages +from .reactions import bp as channel_reactions +from .pins import bp as channel_pins diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py new file mode 100644 index 0000000..4e5c7b2 --- /dev/null +++ b/litecord/blueprints/channel/messages.py @@ -0,0 +1,276 @@ +from quart import Blueprint, request, current_app as app, jsonify + +from logbook import Logger + + +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import channel_check, channel_perm_check +from litecord.blueprints.dms import try_dm_state +from litecord.errors import MessageNotFound, Forbidden, BadRequest +from litecord.enums import MessageType, ChannelType, GUILD_CHANS +from litecord.snowflake import get_snowflake +from litecord.schemas import validate, MESSAGE_CREATE + + +log = Logger(__name__) +bp = Blueprint('channel_messages', __name__) + + +def extract_limit(request, default: int = 50): + try: + limit = int(request.args.get('limit', default)) + + if limit not in range(0, 100): + raise ValueError() + except (TypeError, ValueError): + raise BadRequest('limit not int') + + return limit + + +def query_tuple_from_args(args: dict, limit: int) -> tuple: + before, after = None, None + + if 'around' in request.args: + average = int(limit / 2) + around = int(request.args['around']) + + after = around - average + before = around + average + + elif 'before' in request.args: + before = int(request.args['before']) + elif 'after' in request.args: + before = int(request.args['after']) + + return before, after + + +@bp.route('//messages', methods=['GET']) +async def get_messages(channel_id): + user_id = await token_check() + + # TODO: check READ_MESSAGE_HISTORY permission + ctype, peer_id = await channel_check(user_id, channel_id) + + if ctype == ChannelType.DM: + # make sure both parties will be subbed + # to a dm + await _dm_pre_dispatch(channel_id, user_id) + await _dm_pre_dispatch(channel_id, peer_id) + + limit = extract_limit(request, 50) + + where_clause = '' + before, after = query_tuple_from_args(request.args, limit) + + if before: + where_clause += f'AND id < {before}' + + if after: + where_clause += f'AND id > {after}' + + message_ids = await app.db.fetch(f""" + SELECT id + FROM messages + WHERE channel_id = $1 {where_clause} + ORDER BY id DESC + LIMIT {limit} + """, channel_id) + + result = [] + + for message_id in message_ids: + msg = await app.storage.get_message(message_id['id'], user_id) + + if msg is None: + continue + + result.append(msg) + + log.info('Fetched {} messages', len(result)) + return jsonify(result) + + +@bp.route('//messages/', methods=['GET']) +async def get_single_message(channel_id, message_id): + user_id = await token_check() + await channel_check(user_id, channel_id) + + # TODO: check READ_MESSAGE_HISTORY permissions + message = await app.storage.get_message(message_id, user_id) + + if not message: + raise MessageNotFound() + + return jsonify(message) + + +async def _dm_pre_dispatch(channel_id, peer_id): + """Do some checks pre-MESSAGE_CREATE so we + make sure the receiving party will handle everything.""" + + # check the other party's dm_channel_state + + dm_state = await app.db.fetchval(""" + SELECT dm_id + FROM dm_channel_state + WHERE user_id = $1 AND dm_id = $2 + """, peer_id, channel_id) + + if dm_state: + # the peer already has the channel + # opened, so we don't need to do anything + return + + dm_chan = await app.storage.get_channel(channel_id) + + # dispatch CHANNEL_CREATE so the client knows which + # channel the future event is about + await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan) + + # subscribe the peer to the channel + await app.dispatcher.sub('channel', channel_id, peer_id) + + # insert it on dm_channel_state so the client + # is subscribed on the future + await try_dm_state(peer_id, channel_id) + + +@bp.route('//messages', methods=['POST']) +async def create_message(channel_id): + user_id = await token_check() + ctype, guild_id = await channel_check(user_id, channel_id) + + if ctype in GUILD_CHANS: + await channel_perm_check(user_id, channel_id, 'send_messages') + + j = validate(await request.get_json(), MESSAGE_CREATE) + message_id = get_snowflake() + + # TODO: check connection to the gateway + + mentions_everyone = ('@everyone' in j['content'] and + await channel_perm_check( + user_id, channel_id, 'mention_everyone', False + ) + ) + + is_tts = (j.get('tts', False) and + await channel_perm_check( + user_id, channel_id, 'send_tts_messages', False + )) + + await app.db.execute( + """ + INSERT INTO messages (id, channel_id, author_id, content, tts, + mention_everyone, nonce, message_type) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + """, + message_id, + channel_id, + user_id, + j['content'], + + is_tts, + mentions_everyone, + + int(j.get('nonce', 0)), + MessageType.DEFAULT.value + ) + + payload = await app.storage.get_message(message_id, user_id) + + if ctype == ChannelType.DM: + # guild id here is the peer's ID. + await _dm_pre_dispatch(channel_id, user_id) + await _dm_pre_dispatch(channel_id, guild_id) + + await app.dispatcher.dispatch('channel', channel_id, + 'MESSAGE_CREATE', payload) + + # TODO: dispatch the MESSAGE_CREATE to any mentioning user. + + if ctype == ChannelType.GUILD_TEXT: + for str_uid in payload['mentions']: + uid = int(str_uid) + + await app.db.execute(""" + UPDATE user_read_state + SET mention_count += 1 + WHERE user_id = $1 AND channel_id = $2 + """, uid, channel_id) + + return jsonify(payload) + + +@bp.route('//messages/', methods=['PATCH']) +async def edit_message(channel_id, message_id): + user_id = await token_check() + _ctype, guild_id = await channel_check(user_id, channel_id) + + author_id = await app.db.fetchval(""" + SELECT author_id FROM messages + WHERE messages.id = $1 + """, message_id) + + if not author_id == user_id: + raise Forbidden('You can not edit this message') + + j = await request.get_json() + updated = 'content' in j or 'embed' in j + + if 'content' in j: + await app.db.execute(""" + UPDATE messages + SET content=$1 + WHERE messages.id = $2 + """, j['content'], message_id) + + # TODO: update embed + + message = await app.storage.get_message(message_id, user_id) + + # only dispatch MESSAGE_UPDATE if we actually had any update to start with + if updated: + await app.dispatcher.dispatch('channel', channel_id, + 'MESSAGE_UPDATE', message) + + return jsonify(message) + + +@bp.route('//messages/', methods=['DELETE']) +async def delete_message(channel_id, message_id): + user_id = await token_check() + _ctype, guild_id = await channel_check(user_id, channel_id) + + author_id = await app.db.fetchval(""" + SELECT author_id FROM messages + WHERE messages.id = $1 + """, message_id) + + by_perm = await channel_perm_check( + user_id, channel_id, 'manage_messages', False + ) + + by_ownership = author_id == user_id + + if not by_perm and not by_ownership: + raise Forbidden('You can not delete this message') + + await app.db.execute(""" + DELETE FROM messages + WHERE messages.id = $1 + """, message_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, + 'MESSAGE_DELETE', { + 'id': str(message_id), + 'channel_id': str(channel_id), + + # for lazy guilds + 'guild_id': str(guild_id), + }) + + return '', 204 diff --git a/litecord/blueprints/channel/pins.py b/litecord/blueprints/channel/pins.py new file mode 100644 index 0000000..7d5b42b --- /dev/null +++ b/litecord/blueprints/channel/pins.py @@ -0,0 +1,93 @@ +from quart import Blueprint, current_app as app, request, jsonify + +from litecord.auth import token_check +from litecord.blueprints.checks import channel_check +from litecord.snowflake import snowflake_datetime + +bp = Blueprint('channel_pins', __name__) + + +@bp.route('//pins', methods=['GET']) +async def get_pins(channel_id): + """Get the pins for a channel""" + user_id = await token_check() + await channel_check(user_id, channel_id) + + ids = await app.db.fetch(""" + SELECT message_id + FROM channel_pins + WHERE channel_id = $1 + ORDER BY message_id ASC + """, channel_id) + + ids = [r['message_id'] for r in ids] + res = [] + + for message_id in ids: + message = await app.storage.get_message(message_id) + if message is not None: + res.append(message) + + return jsonify(res) + + +@bp.route('//pins/', methods=['PUT']) +async def add_pin(channel_id, message_id): + """Add a pin to a channel""" + user_id = await token_check() + _ctype, guild_id = await channel_check(user_id, channel_id) + + # TODO: check MANAGE_MESSAGES permission + + await app.db.execute(""" + INSERT INTO channel_pins (channel_id, message_id) + VALUES ($1, $2) + """, channel_id, message_id) + + row = await app.db.fetchrow(""" + SELECT message_id + FROM channel_pins + WHERE channel_id = $1 + ORDER BY message_id ASC + LIMIT 1 + """, channel_id) + + timestamp = snowflake_datetime(row['message_id']) + + await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', { + 'channel_id': str(channel_id), + 'last_pin_timestamp': timestamp.isoformat() + }) + + return '', 204 + + +@bp.route('//pins/', methods=['DELETE']) +async def delete_pin(channel_id, message_id): + user_id = await token_check() + _ctype, guild_id = await channel_check(user_id, channel_id) + + # TODO: check MANAGE_MESSAGES permission + + await app.db.execute(""" + DELETE FROM channel_pins + WHERE channel_id = $1 AND message_id = $2 + """, channel_id, message_id) + + row = await app.db.fetchrow(""" + SELECT message_id + FROM channel_pins + WHERE channel_id = $1 + ORDER BY message_id ASC + LIMIT 1 + """, channel_id) + + timestamp = snowflake_datetime(row['message_id']) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'CHANNEL_PINS_UPDATE', { + 'channel_id': str(channel_id), + 'last_pin_timestamp': timestamp.isoformat() + }) + + return '', 204 diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py new file mode 100644 index 0000000..6db9e6b --- /dev/null +++ b/litecord/blueprints/channel/reactions.py @@ -0,0 +1,225 @@ +from enum import IntEnum + +from quart import Blueprint, request, current_app as app, jsonify +from logbook import Logger + + +from litecord.utils import async_map +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import channel_check +from litecord.blueprints.channel.messages import ( + query_tuple_from_args, extract_limit +) + +from litecord.errors import MessageNotFound, Forbidden, BadRequest +from litecord.enums import GUILD_CHANS + + +log = Logger(__name__) +bp = Blueprint('channel_reactions', __name__) + +BASEPATH = '//messages//reactions' + + +class EmojiType(IntEnum): + CUSTOM = 0 + UNICODE = 1 + + +def emoji_info_from_str(emoji: str) -> tuple: + """Extract emoji information from an emoji string + given on the reaction endpoints.""" + # custom emoji have an emoji of name:id + # unicode emoji just have the raw unicode. + + # try checking if the emoji is custom or unicode + emoji_type = 0 if ':' in emoji else 1 + emoji_type = EmojiType(emoji_type) + + # extract the emoji id OR the unicode value of the emoji + # depending if it is custom or not + emoji_id = (int(emoji.split(':')[1]) + if emoji_type == EmojiType.CUSTOM + else emoji) + + emoji_name = emoji.split(':')[0] + + return emoji_type, emoji_id, emoji_name + + +def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict: + print(emoji_type, emoji_id, emoji_name) + return { + 'id': None if emoji_type == EmojiType.UNICODE else emoji_id, + 'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id + } + + +def _make_payload(user_id, channel_id, message_id, partial): + return { + 'user_id': str(user_id), + 'channel_id': str(channel_id), + 'message_id': str(message_id), + 'emoji': partial + } + + +@bp.route(f'{BASEPATH}//@me', methods=['PUT']) +async def add_reaction(channel_id: int, message_id: int, emoji: str): + """Put a reaction.""" + user_id = await token_check() + + # TODO: check READ_MESSAGE_HISTORY permission + # and ADD_REACTIONS. look on route docs. + ctype, guild_id = await channel_check(user_id, channel_id) + + emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) + + await app.db.execute( + """ + INSERT INTO message_reactions (message_id, user_id, + emoji_type, emoji_id, emoji_text) + VALUES ($1, $2, $3, $4, $5) + """, message_id, user_id, emoji_type, + + # if it is custom, we put the emoji_id on emoji_id + # column, if it isn't, we put it on emoji_text + # column. + emoji_id if emoji_type == EmojiType.CUSTOM else None, + emoji_id if emoji_type == EmojiType.UNICODE else None + ) + + partial = partial_emoji(emoji_type, emoji_id, emoji_name) + payload = _make_payload(user_id, channel_id, message_id, partial) + + if ctype in GUILD_CHANS: + payload['guild_id'] = str(guild_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'MESSAGE_REACTION_ADD', payload) + + return '', 204 + + +def emoji_sql(emoji_type, emoji_id, emoji_name, param=4): + """Extract SQL clauses to search for specific emoji + in the message_reactions table.""" + param = f'${param}' + + # know which column to filter with + where_ext = (f'AND emoji_id = {param}' + if emoji_type == EmojiType.CUSTOM else + f'AND emoji_text = {param}') + + # which emoji to remove (custom or unicode) + main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name + + return where_ext, main_emoji + + +def _emoji_sql_simple(emoji: str, param=4): + """Simpler version of _emoji_sql for functions that + don't need the results from emoji_info_from_str.""" + emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) + return emoji_sql(emoji_type, emoji_id, emoji_name, param) + + +async def remove_reaction(channel_id: int, message_id: int, + user_id: int, emoji: str): + ctype, guild_id = await channel_check(user_id, channel_id) + + emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) + where_ext, main_emoji = emoji_sql(emoji_type, emoji_id, emoji_name) + + await app.db.execute( + f""" + DELETE FROM message_reactions + WHERE message_id = $1 + AND user_id = $2 + AND emoji_type = $3 + {where_ext} + """, message_id, user_id, emoji_type, main_emoji) + + partial = partial_emoji(emoji_type, emoji_id, emoji_name) + payload = _make_payload(user_id, channel_id, message_id, partial) + + if ctype in GUILD_CHANS: + payload['guild_id'] = str(guild_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload) + + +@bp.route(f'{BASEPATH}//@me', methods=['DELETE']) +async def remove_own_reaction(channel_id, message_id, emoji): + """Remove a reaction.""" + user_id = await token_check() + + await remove_reaction(channel_id, message_id, user_id, emoji) + + return '', 204 + + +@bp.route(f'{BASEPATH}//', methods=['DELETE']) +async def remove_user_reaction(channel_id, message_id, emoji, other_id): + """Remove a reaction made by another user.""" + await token_check() + + # TODO: check MANAGE_MESSAGES permission (and use user_id + # from token_check to do it) + await remove_reaction(channel_id, message_id, other_id, emoji) + + return '', 204 + + +@bp.route(f'{BASEPATH}/', methods=['GET']) +async def list_users_reaction(channel_id, message_id, emoji): + """Get the list of all users who reacted with a certain emoji.""" + user_id = await token_check() + + # this is not using either ctype or guild_id + # that are returned by channel_check + await channel_check(user_id, channel_id) + + limit = extract_limit(request, 25) + before, after = query_tuple_from_args(request.args, limit) + + before_clause = 'AND user_id < $2' if before else '' + after_clause = 'AND user_id > $3' if after else '' + + where_ext, main_emoji = _emoji_sql_simple(emoji, 4) + + rows = await app.db.fetch(f""" + SELECT user_id + FROM message_reactions + WHERE message_id = $1 {before_clause} {after_clause} {where_ext} + """, message_id, before, after, main_emoji) + + user_ids = [r['user_id'] for r in rows] + users = await async_map(app.storage.get_user, user_ids) + return jsonify(users) + + +@bp.route(f'{BASEPATH}', methods=['DELETE']) +async def remove_all_reactions(channel_id, message_id): + """Remove all reactions in a message.""" + user_id = await token_check() + + # TODO: check MANAGE_MESSAGES permission + ctype, guild_id = await channel_check(user_id, channel_id) + + await app.db.execute(""" + DELETE FROM message_reactions + WHERE message_id = $1 + """, message_id) + + payload = { + 'channel_id': str(channel_id), + 'message_id': str(message_id), + } + + if ctype in GUILD_CHANS: + payload['guild_id'] = str(guild_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 2ab084f..ed5e62e 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -3,14 +3,14 @@ import time from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger -from ..auth import token_check -from ..snowflake import get_snowflake, snowflake_datetime -from ..enums import ChannelType, MessageType, GUILD_CHANS -from ..errors import Forbidden, ChannelNotFound, MessageNotFound -from ..schemas import validate, MESSAGE_CREATE +from litecord.auth import token_check +from litecord.enums import ChannelType, GUILD_CHANS +from litecord.errors import ChannelNotFound +from litecord.schemas import ( + validate, CHAN_UPDATE, CHAN_OVERWRITE +) -from .checks import channel_check, guild_check -from .dms import try_dm_state +from litecord.blueprints.checks import channel_check, channel_perm_check log = Logger(__name__) bp = Blueprint('channels', __name__) @@ -136,6 +136,7 @@ async def guild_cleanup(channel_id): @bp.route('/', methods=['DELETE']) async def close_channel(channel_id): + """Close or delete a channel.""" user_id = await token_check() chan_type = await app.storage.get_chan_type(channel_id) @@ -212,287 +213,199 @@ async def close_channel(channel_id): # TODO: group dm pass - return '', 404 + raise ChannelNotFound() -@bp.route('//messages', methods=['GET']) -async def get_messages(channel_id): - user_id = await token_check() - await channel_check(user_id, channel_id) - - # TODO: before, after, around keys - - message_ids = await app.db.fetch(f""" - SELECT id - FROM messages - WHERE channel_id = $1 - ORDER BY id DESC - LIMIT 100 - """, channel_id) - - result = [] - - for message_id in message_ids: - msg = await app.storage.get_message(message_id['id']) - - if msg is None: - continue - - result.append(msg) - - log.info('Fetched {} messages', len(result)) - return jsonify(result) +async def _update_pos(channel_id, pos: int): + await app.db.execute(""" + UPDATE guild_channels + SET position = $1 + WHERE id = $2 + """, pos, channel_id) -@bp.route('//messages/', methods=['GET']) -async def get_single_message(channel_id, message_id): - user_id = await token_check() - await channel_check(user_id, channel_id) - - # TODO: check READ_MESSAGE_HISTORY permissions - message = await app.storage.get_message(message_id) - - if not message: - raise MessageNotFound() - - return jsonify(message) +async def _mass_chan_update(guild_id, channel_ids: int): + for channel_id in channel_ids: + chan = await app.storage.get_channel(channel_id) + await app.dispatcher.dispatch( + 'guild', guild_id, 'CHANNEL_UPDATE', chan) -async def _dm_pre_dispatch(channel_id, peer_id): - """Do some checks pre-MESSAGE_CREATE so we - make sure the receiving party will handle everything.""" +async def _process_overwrites(channel_id: int, overwrites: list): + for overwrite in overwrites: - # check the other party's dm_channel_state + # 0 for user overwrite, 1 for role overwrite + target_type = 0 if overwrite['type'] == 'user' else 1 + target_role = None if target_type == 0 else overwrite['id'] + target_user = overwrite['id'] if target_type == 0 else None - dm_state = await app.db.fetchval(""" - SELECT dm_id - FROM dm_channel_state - WHERE user_id = $1 AND dm_id = $2 - """, peer_id, channel_id) - - if dm_state: - # the peer already has the channel - # opened, so we don't need to do anything - return - - dm_chan = await app.storage.get_channel(channel_id) - - # dispatch CHANNEL_CREATE so the client knows which - # channel the future event is about - await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan) - - # subscribe the peer to the channel - await app.dispatcher.sub('channel', channel_id, peer_id) - - # insert it on dm_channel_state so the client - # is subscribed on the future - await try_dm_state(peer_id, channel_id) + await app.db.execute( + """ + INSERT INTO channel_overwrites + (channel_id, target_type, target_role, + target_user, allow, deny) + VALUES + ($1, $2, $3, $4, $5, $6) + ON CONFLICT ON CONSTRAINT channel_overwrites_uniq + DO + UPDATE + SET allow = $5, deny = $6 + WHERE channel_overwrites.channel_id = $1 + AND channel_overwrites.target_type = $2 + AND channel_overwrites.target_role = $3 + AND channel_overwrites.target_user = $4 + """, + channel_id, target_type, + target_role, target_user, + overwrite['allow'], overwrite['deny']) -@bp.route('//messages', methods=['POST']) -async def create_message(channel_id): +@bp.route('//permissions/', methods=['PUT']) +async def put_channel_overwrite(channel_id: int, overwrite_id: int): + """Insert or modify a channel overwrite.""" user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) - j = validate(await request.get_json(), MESSAGE_CREATE) - message_id = get_snowflake() + if ctype not in GUILD_CHANS: + raise ChannelNotFound('Only usable for guild channels.') - # TODO: check SEND_MESSAGES permission - # TODO: check connection to the gateway + await channel_perm_check(user_id, guild_id, 'manage_roles') - await app.db.execute( - """ - INSERT INTO messages (id, channel_id, author_id, content, tts, - mention_everyone, nonce, message_type) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - """, - message_id, - channel_id, - user_id, - j['content'], - - # TODO: check SEND_TTS_MESSAGES - j.get('tts', False), - - # TODO: check MENTION_EVERYONE permissions - '@everyone' in j['content'], - int(j.get('nonce', 0)), - MessageType.DEFAULT.value + j = validate( + # inserting a fake id on the payload so validation passes through + {**await request.get_json(), **{'id': -1}}, + CHAN_OVERWRITE ) - payload = await app.storage.get_message(message_id) - - if ctype == ChannelType.DM: - # guild id here is the peer's ID. - await _dm_pre_dispatch(channel_id, guild_id) + await _process_overwrites(channel_id, [{ + 'allow': j['allow'], + 'deny': j['deny'], + 'type': j['type'], + 'id': overwrite_id + }]) - await app.dispatcher.dispatch('channel', channel_id, - 'MESSAGE_CREATE', payload) - - # TODO: dispatch the MESSAGE_CREATE to any mentioning user. - - if ctype == ChannelType.GUILD_TEXT: - for str_uid in payload['mentions']: - uid = int(str_uid) - - await app.db.execute(""" - UPDATE user_read_state - SET mention_count += 1 - WHERE user_id = $1 AND channel_id = $2 - """, uid, channel_id) - - return jsonify(payload) + await _mass_chan_update(guild_id, [channel_id]) + return '', 204 -@bp.route('//messages/', methods=['PATCH']) -async def edit_message(channel_id, message_id): - user_id = await token_check() - _ctype, guild_id = await channel_check(user_id, channel_id) - - author_id = await app.db.fetchval(""" - SELECT author_id FROM messages - WHERE messages.id = $1 - """, message_id) - - if not author_id == user_id: - raise Forbidden('You can not edit this message') - - j = await request.get_json() - updated = 'content' in j or 'embed' in j - - if 'content' in j: +async def _update_channel_common(channel_id, guild_id: int, j: dict): + if 'name' in j: await app.db.execute(""" - UPDATE messages - SET content=$1 - WHERE messages.id = $2 - """, j['content'], message_id) + UPDATE guild_channels + SET name = $1 + WHERE id = $2 + """, j['name'], channel_id) - # TODO: update embed + if 'position' in j: + channel_data = await app.storage.get_channel_data(guild_id) - message = await app.storage.get_message(message_id) + chans = [None * len(channel_data)] + for chandata in channel_data: + chans.insert(chandata['position'], int(chandata['id'])) - # only dispatch MESSAGE_UPDATE if we actually had any update to start with - if updated: - await app.dispatcher.dispatch('channel', channel_id, - 'MESSAGE_UPDATE', message) + # are we changing to the left or to the right? - return jsonify(message) + # left: [channel1, channel2, ..., channelN-1, channelN] + # becomes + # [channel1, channelN-1, channel2, ..., channelN] + # so we can say that the "main change" is + # channelN-1 going to the position channel2 + # was occupying. + current_pos = chans.index(channel_id) + new_pos = j['position'] + + # if the new position is bigger than the current one, + # we're making a left shift of all the channels that are + # beyond the current one, to make space + left_shift = new_pos > current_pos + + # find all channels that we'll have to shift + shift_block = (chans[current_pos:new_pos] + if left_shift else + chans[new_pos:current_pos] + ) + + shift = -1 if left_shift else 1 + + # do the shift (to the left or to the right) + await app.db.executemany(""" + UPDATE guild_channels + SET position = position + $1 + WHERE id = $2 + """, [(shift, chan_id) for chan_id in shift_block]) + + await _mass_chan_update(guild_id, shift_block) + + # since theres now an empty slot, move current channel to it + await _update_pos(channel_id, new_pos) + + if 'channel_overwrites' in j: + overwrites = j['channel_overwrites'] + await _process_overwrites(channel_id, overwrites) -@bp.route('//messages/', methods=['DELETE']) -async def delete_message(channel_id, message_id): +async def _common_guild_chan(channel_id, j: dict): + # common updates to the guild_channels table + for field in [field for field in j.keys() + if field in ('nsfw', 'parent_id')]: + await app.db.execute(f""" + UPDATE guild_channels + SET {field} = $1 + WHERE id = $2 + """, j[field], channel_id) + + +async def _update_text_channel(channel_id: int, j: dict): + # first do the specific ones related to guild_text_channels + for field in [field for field in j.keys() + if field in ('topic', 'rate_limit_per_user')]: + await app.db.execute(f""" + UPDATE guild_text_channels + SET {field} = $1 + WHERE id = $2 + """, j[field], channel_id) + + await _common_guild_chan(channel_id, j) + + +async def _update_voice_channel(channel_id: int, j: dict): + # first do the specific ones in guild_voice_channels + for field in [field for field in j.keys() + if field in ('bitrate', 'user_limit')]: + await app.db.execute(f""" + UPDATE guild_voice_channels + SET {field} = $1 + WHERE id = $2 + """, j[field], channel_id) + + # yes, i'm letting voice channels have nsfw, you cant stop me + await _common_guild_chan(channel_id, j) + + +@bp.route('/', methods=['PUT', 'PATCH']) +async def update_channel(channel_id): + """Update a channel's information""" user_id = await token_check() - _ctype, guild_id = await channel_check(user_id, channel_id) + ctype, guild_id = await channel_check(user_id, channel_id) - author_id = await app.db.fetchval(""" - SELECT author_id FROM messages - WHERE messages.id = $1 - """, message_id) + if ctype not in GUILD_CHANS: + raise ChannelNotFound('Can not edit non-guild channels.') - # TODO: MANAGE_MESSAGES permission check - if author_id != user_id: - raise Forbidden('You can not delete this message') + await channel_perm_check(user_id, channel_id, 'manage_channels') + j = validate(await request.get_json(), CHAN_UPDATE) - await app.db.execute(""" - DELETE FROM messages - WHERE messages.id = $1 - """, message_id) + # TODO: categories? + update_handler = { + ChannelType.GUILD_TEXT: _update_text_channel, + ChannelType.GUILD_VOICE: _update_voice_channel, + }[ctype] - await app.dispatcher.dispatch( - 'channel', channel_id, - 'MESSAGE_DELETE', { - 'id': str(message_id), - 'channel_id': str(channel_id), + await _update_channel_common(channel_id, guild_id, j) + await update_handler(channel_id, j) - # for lazy guilds - 'guild_id': str(guild_id), - }) - - return '', 204 - - -@bp.route('//pins', methods=['GET']) -async def get_pins(channel_id): - user_id = await token_check() - await channel_check(user_id, channel_id) - - ids = await app.db.fetch(""" - SELECT message_id - FROM channel_pins - WHERE channel_id = $1 - ORDER BY message_id ASC - """, channel_id) - - ids = [r['message_id'] for r in ids] - res = [] - - for message_id in ids: - message = await app.storage.get_message(message_id) - if message is not None: - res.append(message) - - return jsonify(message) - - -@bp.route('//pins/', methods=['PUT']) -async def add_pin(channel_id, message_id): - user_id = await token_check() - _ctype, guild_id = await channel_check(user_id, channel_id) - - # TODO: check MANAGE_MESSAGES permission - - await app.db.execute(""" - INSERT INTO channel_pins (channel_id, message_id) - VALUES ($1, $2) - """, channel_id, message_id) - - row = await app.db.fetchrow(""" - SELECT message_id - FROM channel_pins - WHERE channel_id = $1 - ORDER BY message_id ASC - LIMIT 1 - """, channel_id) - - timestamp = snowflake_datetime(row['message_id']) - - await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', { - 'channel_id': str(channel_id), - 'last_pin_timestamp': timestamp.isoformat() - }) - - return '', 204 - - -@bp.route('//pins/', methods=['DELETE']) -async def delete_pin(channel_id, message_id): - user_id = await token_check() - _ctype, guild_id = await channel_check(user_id, channel_id) - - # TODO: check MANAGE_MESSAGES permission - - await app.db.execute(""" - DELETE FROM channel_pins - WHERE channel_id = $1 AND message_id = $2 - """, channel_id, message_id) - - row = await app.db.fetchrow(""" - SELECT message_id - FROM channel_pins - WHERE channel_id = $1 - ORDER BY message_id ASC - LIMIT 1 - """, channel_id) - - timestamp = snowflake_datetime(row['message_id']) - - await app.dispatcher.dispatch( - 'channel', channel_id, 'CHANNEL_PINS_UPDATE', { - 'channel_id': str(channel_id), - 'last_pin_timestamp': timestamp.isoformat() - }) - - return '', 204 + chan = await app.storage.get_channel(channel_id) + await app.dispatcher.dispatch('guild', guild_id, 'CHANNEL_UPDATE', chan) + return jsonify(chan) @bp.route('//typing', methods=['POST']) @@ -518,21 +431,18 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): if not message_id: message_id = await app.storage.chan_last_message(channel_id) - res = await app.db.execute(""" - UPDATE user_read_state - - SET last_message_id = $1, - mention_count = 0 - - WHERE user_id = $2 AND channel_id = $3 - """, message_id, user_id, channel_id) - - if res == 'UPDATE 0': - await app.db.execute(""" - INSERT INTO user_read_state - (user_id, channel_id, last_message_id, mention_count) - VALUES ($1, $2, $3, $4) - """, user_id, channel_id, message_id, 0) + await app.db.execute(""" + INSERT INTO user_read_state + (user_id, channel_id, last_message_id, mention_count) + VALUES + ($1, $2, $3, 0) + ON CONFLICT ON CONSTRAINT user_read_state_pkey + DO + UPDATE + SET last_message_id = $3, mention_count = 0 + WHERE user_read_state.user_id = $1 + AND user_read_state.channel_id = $2 + """, user_id, channel_id, message_id) if guild_id: await app.dispatcher.dispatch_user_guild( @@ -551,6 +461,7 @@ async def channel_ack(user_id, guild_id, channel_id, message_id: int = None): @bp.route('//messages//ack', methods=['POST']) async def ack_channel(channel_id, message_id): + """Acknowledge a channel.""" user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) @@ -569,6 +480,7 @@ async def ack_channel(channel_id, message_id): @bp.route('//messages/ack', methods=['DELETE']) async def delete_read_state(channel_id): + """Delete the read state of a channel.""" user_id = await token_check() await channel_check(user_id, channel_id) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 5cfc225..e051c11 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -1,7 +1,10 @@ from quart import current_app as app -from ..enums import ChannelType, GUILD_CHANS -from ..errors import GuildNotFound, ChannelNotFound +from litecord.enums import ChannelType, GUILD_CHANS +from litecord.errors import ( + GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions +) +from litecord.permissions import base_permissions, get_permissions async def guild_check(user_id: int, guild_id: int): @@ -16,6 +19,21 @@ async def guild_check(user_id: int, guild_id: int): raise GuildNotFound('guild not found') +async def guild_owner_check(user_id: int, guild_id: int): + """Check if a user is the owner of the guild.""" + owner_id = await app.db.fetchval(""" + SELECT owner_id + FROM guilds + WHERE guilds.id = $1 + """, guild_id) + + if not owner_id: + raise GuildNotFound() + + if user_id != owner_id: + raise Forbidden('You are not the owner of the guild') + + async def channel_check(user_id, channel_id): """Check if the current user is authorized to read the channel's information.""" @@ -39,3 +57,27 @@ async def channel_check(user_id, channel_id): if ctype == ChannelType.DM: peer_id = await app.storage.get_dm_peer(channel_id, user_id) return ctype, peer_id + + +async def guild_perm_check(user_id, guild_id, permission: str): + """Check guild permissions for a user.""" + base_perms = await base_permissions(user_id, guild_id) + hasperm = getattr(base_perms.bits, permission) + + if not hasperm: + raise MissingPermissions('Missing permissions.') + + +async def channel_perm_check(user_id, channel_id, + permission: str, raise_err=True): + """Check channel permissions for a user.""" + base_perms = await get_permissions(user_id, channel_id) + hasperm = getattr(base_perms.bits, permission) + + print(base_perms) + print(base_perms.binary) + + if not hasperm and raise_err: + raise MissingPermissions('Missing permissions.') + + return hasperm diff --git a/litecord/blueprints/dms.py b/litecord/blueprints/dms.py index 7d625a2..2a83c54 100644 --- a/litecord/blueprints/dms.py +++ b/litecord/blueprints/dms.py @@ -38,41 +38,47 @@ async def try_dm_state(user_id: int, dm_id: int): """, user_id, dm_id) +async def jsonify_dm(dm_id: int, user_id: int): + dm_chan = await app.storage.get_dm(dm_id, user_id) + return jsonify(dm_chan) + + async def create_dm(user_id, recipient_id): """Create a new dm with a user, or get the existing DM id if it already exists.""" + + dm_id = await app.db.fetchval(""" + SELECT id + FROM dm_channels + WHERE (party1_id = $1 OR party2_id = $1) AND + (party1_id = $2 OR party2_id = $2) + """, user_id, recipient_id) + + if dm_id: + return await jsonify_dm(dm_id, user_id) + + # if no dm was found, create a new one + dm_id = get_snowflake() + await app.db.execute(""" + INSERT INTO channels (id, channel_type) + VALUES ($1, $2) + """, dm_id, ChannelType.DM.value) - try: - await app.db.execute(""" - INSERT INTO channels (id, channel_type) - VALUES ($1, $2) - """, dm_id, ChannelType.DM.value) + await app.db.execute(""" + INSERT INTO dm_channels (id, party1_id, party2_id) + VALUES ($1, $2, $3) + """, dm_id, user_id, recipient_id) - await app.db.execute(""" - INSERT INTO dm_channels (id, party1_id, party2_id) - VALUES ($1, $2, $3) - """, dm_id, user_id, recipient_id) + # the dm state is something we use + # to give the currently "open dms" + # on the client. - # the dm state is something we use - # to give the currently "open dms" - # on the client. + # we don't open a dm for the peer/recipient + # until the user sends a message. + await try_dm_state(user_id, dm_id) - # we don't open a dm for the peer/recipient - # until the user sends a message. - await try_dm_state(user_id, dm_id) - - except UniqueViolationError: - # the dm already exists - dm_id = await app.db.fetchval(""" - SELECT id - FROM dm_channels - WHERE (party1_id = $1 OR party2_id = $1) AND - (party2_id = $2 OR party2_id = $2) - """, user_id, recipient_id) - - dm = await app.storage.get_dm(dm_id, user_id) - return jsonify(dm) + return await jsonify_dm(dm_id, user_id) @bp.route('/@me/channels', methods=['POST']) diff --git a/litecord/blueprints/gateway.py b/litecord/blueprints/gateway.py index 9f1c4df..a301cac 100644 --- a/litecord/blueprints/gateway.py +++ b/litecord/blueprints/gateway.py @@ -1,3 +1,5 @@ +import time + from quart import Blueprint, jsonify, current_app as app from ..auth import token_check @@ -6,12 +8,14 @@ bp = Blueprint('gateway', __name__) def get_gw(): + """Get the gateway's web""" proto = 'wss://' if app.config['IS_SSL'] else 'ws://' return f'{proto}{app.config["WEBSOCKET_URL"]}/ws' @bp.route('/gateway') def api_gateway(): + """Get the raw URL.""" return jsonify({ 'url': get_gw() }) @@ -27,9 +31,25 @@ async def api_gateway_bot(): WHERE user_id = $1 """, user_id) - shards = max(int(guild_count / 1200), 1) + shards = max(int(guild_count / 1000), 1) + + # get _ws.session ratelimit + ratelimit = app.ratelimiter.get_ratelimit('_ws.session') + bucket = ratelimit.get_bucket(user_id) + + # timestamp of bucket reset + reset_ts = bucket._window + bucket.second + + # how many seconds until bucket reset + reset_after_ts = reset_ts - time.time() return jsonify({ 'url': get_gw(), 'shards': shards, + + 'session_start_limit': { + 'total': bucket.requests, + 'remaining': bucket._tokens, + 'reset_after': int(reset_after_ts * 1000), + } }) diff --git a/litecord/blueprints/guild/__init__.py b/litecord/blueprints/guild/__init__.py new file mode 100644 index 0000000..36c36b6 --- /dev/null +++ b/litecord/blueprints/guild/__init__.py @@ -0,0 +1,4 @@ +from .roles import bp as guild_roles +from .members import bp as guild_members +from .channels import bp as guild_channels +from .mod import bp as guild_mod diff --git a/litecord/blueprints/guild/channels.py b/litecord/blueprints/guild/channels.py new file mode 100644 index 0000000..21d89ec --- /dev/null +++ b/litecord/blueprints/guild/channels.py @@ -0,0 +1,180 @@ +from quart import Blueprint, request, current_app as app, jsonify + +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import guild_check, guild_owner_check +from litecord.snowflake import get_snowflake +from litecord.errors import BadRequest +from litecord.enums import ChannelType +from litecord.schemas import ( + validate, ROLE_UPDATE_POSITION +) +from litecord.blueprints.guild.roles import gen_pairs + + +bp = Blueprint('guild_channels', __name__) + + +async def _specific_chan_create(channel_id, ctype, **kwargs): + if ctype == ChannelType.GUILD_TEXT: + await app.db.execute(""" + INSERT INTO guild_text_channels (id, topic) + VALUES ($1, $2) + """, channel_id, kwargs.get('topic', '')) + elif ctype == ChannelType.GUILD_VOICE: + await app.db.execute( + """ + INSERT INTO guild_voice_channels (id, bitrate, user_limit) + VALUES ($1, $2, $3) + """, + channel_id, + kwargs.get('bitrate', 64), + kwargs.get('user_limit', 0) + ) + + +async def create_guild_channel(guild_id: int, channel_id: int, + ctype: ChannelType, **kwargs): + """Create a channel in a guild.""" + await app.db.execute(""" + INSERT INTO channels (id, channel_type) + VALUES ($1, $2) + """, channel_id, ctype.value) + + # calc new pos + max_pos = await app.db.fetchval(""" + SELECT MAX(position) + FROM guild_channels + WHERE guild_id = $1 + """, guild_id) + + # account for the first channel in a guild too + max_pos = max_pos or 0 + + # all channels go to guild_channels + await app.db.execute(""" + INSERT INTO guild_channels (id, guild_id, name, position) + VALUES ($1, $2, $3, $4) + """, channel_id, guild_id, kwargs['name'], max_pos + 1) + + # the rest of sql magic is dependant on the channel + # we're creating (a text or voice or category), + # so we use this function. + await _specific_chan_create(channel_id, ctype, **kwargs) + + +@bp.route('//channels', methods=['GET']) +async def get_guild_channels(guild_id): + """Get the list of channels in a guild.""" + user_id = await token_check() + await guild_check(user_id, guild_id) + + return jsonify( + await app.storage.get_channel_data(guild_id)) + + +@bp.route('//channels', methods=['POST']) +async def create_channel(guild_id): + """Create a channel in a guild.""" + user_id = await token_check() + j = await request.get_json() + + # TODO: check permissions for MANAGE_CHANNELS + await guild_check(user_id, guild_id) + + channel_type = j.get('type', ChannelType.GUILD_TEXT) + channel_type = ChannelType(channel_type) + + if channel_type not in (ChannelType.GUILD_TEXT, + ChannelType.GUILD_VOICE): + raise BadRequest('Invalid channel type') + + new_channel_id = get_snowflake() + await create_guild_channel( + guild_id, new_channel_id, channel_type, **j) + + # TODO: do a better method + # subscribe the currently subscribed users to the new channel + # by getting all user ids and subscribing each one by one. + + # since GuildDispatcher calls Storage.get_channel_ids, + # it will subscribe all users to the newly created channel. + guild_pubsub = app.dispatcher.backends['guild'] + user_ids = guild_pubsub.state[guild_id] + for uid in user_ids: + await app.dispatcher.sub('guild', guild_id, uid) + + chan = await app.storage.get_channel(new_channel_id) + await app.dispatcher.dispatch_guild( + guild_id, 'CHANNEL_CREATE', chan) + return jsonify(chan) + + +async def _chan_update_dispatch(guild_id: int, channel_id: int): + """Fetch new information about the channel and dispatch + a single CHANNEL_UPDATE event to the guild.""" + chan = await app.storage.get_channel(channel_id) + await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_UPDATE', chan) + + +async def _do_single_swap(guild_id: int, pair: tuple): + """Do a single channel swap, dispatching + the CHANNEL_UPDATE events for after the swap""" + pair1, pair2 = pair + channel_1, new_pos_1 = pair1 + channel_2, new_pos_2 = pair2 + + # do the swap in a transaction. + conn = await app.db.acquire() + + async with conn.transaction(): + await conn.executemany(""" + UPDATE guild_channels + SET position = $1 + WHERE id = $2 AND guild_id = $3 + """, [ + (new_pos_1, channel_1, guild_id), + (new_pos_2, channel_2, guild_id)]) + + await app.db.release(conn) + + await _chan_update_dispatch(guild_id, channel_1) + await _chan_update_dispatch(guild_id, channel_2) + + +async def _do_channel_swaps(guild_id: int, swap_pairs: list): + """Swap channel pairs' positions, given the list + of pairs to do. + + Dispatches CHANNEL_UPDATEs to the guild. + """ + for pair in swap_pairs: + await _do_single_swap(guild_id, pair) + + +@bp.route('//channels', methods=['PATCH']) +async def modify_channel_pos(guild_id): + """Change positions of channels in a guild.""" + user_id = await token_check() + + # TODO: check MANAGE_CHANNELS + await guild_owner_check(user_id, guild_id) + + # same thing as guild.roles, so we use + # the same schema and all. + raw_j = await request.get_json() + j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION) + j = j['roles'] + + channels = await app.storage.get_channel_data(guild_id) + + channel_positions = {chan['position']: int(chan['id']) + for chan in channels} + + swap_pairs = gen_pairs( + j, + channel_positions + ) + + await _do_channel_swaps(guild_id, swap_pairs) + + return '', 204 diff --git a/litecord/blueprints/guild/members.py b/litecord/blueprints/guild/members.py new file mode 100644 index 0000000..8d90435 --- /dev/null +++ b/litecord/blueprints/guild/members.py @@ -0,0 +1,172 @@ +from quart import Blueprint, request, current_app as app, jsonify + +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import guild_check +from litecord.errors import BadRequest +from litecord.schemas import ( + validate, MEMBER_UPDATE +) +from litecord.blueprints.checks import guild_owner_check + + +bp = Blueprint('guild_members', __name__) + + +@bp.route('//members/', methods=['GET']) +async def get_guild_member(guild_id, member_id): + """Get a member's information in a guild.""" + user_id = await token_check() + await guild_check(user_id, guild_id) + member = await app.storage.get_single_member(guild_id, member_id) + return jsonify(member) + + +@bp.route('//members', methods=['GET']) +async def get_members(guild_id): + """Get members inside a guild.""" + user_id = await token_check() + await guild_check(user_id, guild_id) + + j = await request.get_json() + + limit, after = int(j.get('limit', 1)), j.get('after', 0) + + if limit < 1 or limit > 1000: + raise BadRequest('limit not in 1-1000 range') + + user_ids = await app.db.fetch(f""" + SELECT user_id + WHERE guild_id = $1, user_id > $2 + LIMIT {limit} + ORDER BY user_id ASC + """, guild_id, after) + + user_ids = [r[0] for r in user_ids] + members = await app.storage.get_member_multi(guild_id, user_ids) + return jsonify(members) + + +async def _update_member_roles(guild_id: int, member_id: int, + wanted_roles: list): + """Update the roles a member has.""" + + # first, fetch all current roles + roles = await app.db.fetch(""" + SELECT role_id from member_roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + roles = [r['role_id'] for r in roles] + + roles = set(roles) + wanted_roles = set(wanted_roles) + + # first, we need to find all added roles: + # roles that are on wanted_roles but + # not on roles + added_roles = wanted_roles - roles + + # and then the removed roles + # which are roles in roles, but not + # in wanted_roles + removed_roles = roles - wanted_roles + + conn = await app.db.acquire() + + async with conn.transaction(): + # add roles + await app.db.executemany(""" + INSERT INTO member_roles (user_id, guild_id, role_id) + VALUES ($1, $2, $3) + """, [(member_id, guild_id, role_id) + for role_id in added_roles]) + + # remove roles + await app.db.executemany(""" + DELETE FROM member_roles + WHERE + user_id = $1 + AND guild_id = $2 + AND role_id = $3 + """, [(member_id, guild_id, role_id) + for role_id in removed_roles]) + + await app.db.release(conn) + + +@bp.route('//members/', methods=['PATCH']) +async def modify_guild_member(guild_id, member_id): + """Modify a members' information in a guild.""" + user_id = await token_check() + await guild_owner_check(user_id, guild_id) + + j = validate(await request.get_json(), MEMBER_UPDATE) + + if 'nick' in j: + # TODO: check MANAGE_NICKNAMES + + await app.db.execute(""" + UPDATE members + SET nickname = $1 + WHERE user_id = $2 AND guild_id = $3 + """, j['nick'], member_id, guild_id) + + if 'mute' in j: + # TODO: check MUTE_MEMBERS + + await app.db.execute(""" + UPDATE members + SET muted = $1 + WHERE user_id = $2 AND guild_id = $3 + """, j['mute'], member_id, guild_id) + + if 'deaf' in j: + # TODO: check DEAFEN_MEMBERS + + await app.db.execute(""" + UPDATE members + SET deafened = $1 + WHERE user_id = $2 AND guild_id = $3 + """, j['deaf'], member_id, guild_id) + + if 'channel_id' in j: + # TODO: check MOVE_MEMBERS and CONNECT to the channel + # TODO: change the member's voice channel + pass + + if 'roles' in j: + # TODO: check permissions + await _update_member_roles(guild_id, member_id, j['roles']) + + member = await app.storage.get_member_data_one(guild_id, member_id) + member.pop('joined_at') + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ + 'guild_id': str(guild_id) + }, **member}) + + return '', 204 + + +@bp.route('//members/@me/nick', methods=['PATCH']) +async def update_nickname(guild_id): + """Update a member's nickname in a guild.""" + user_id = await token_check() + await guild_check(user_id, guild_id) + + j = await request.get_json() + + await app.db.execute(""" + UPDATE members + SET nickname = $1 + WHERE user_id = $2 AND guild_id = $3 + """, j['nick'], user_id, guild_id) + + member = await app.storage.get_member_data_one(guild_id, user_id) + member.pop('joined_at') + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ + 'guild_id': str(guild_id) + }, **member}) + + return j['nick'] diff --git a/litecord/blueprints/guild/mod.py b/litecord/blueprints/guild/mod.py new file mode 100644 index 0000000..0461da3 --- /dev/null +++ b/litecord/blueprints/guild/mod.py @@ -0,0 +1,185 @@ +from quart import Blueprint, request, current_app as app, jsonify + +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import guild_owner_check + +from litecord.schemas import validate, GUILD_PRUNE + +bp = Blueprint('guild_moderation', __name__) + + +async def remove_member(guild_id: int, member_id: int): + """Do common tasks related to deleting a member from the guild, + such as dispatching GUILD_DELETE and GUILD_MEMBER_REMOVE.""" + + await app.db.execute(""" + DELETE FROM members + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', { + 'guild_id': guild_id, + 'unavailable': False, + }) + + await app.dispatcher.unsub('guild', guild_id, member_id) + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', { + 'guild_id': str(guild_id), + 'user': await app.storage.get_user(member_id), + }) + + +async def remove_member_multi(guild_id: int, members: list): + """Remove multiple members.""" + for member_id in members: + await remove_member(guild_id, member_id) + + +@bp.route('//members/', methods=['DELETE']) +async def kick_guild_member(guild_id, member_id): + """Remove a member from a guild.""" + user_id = await token_check() + + # TODO: check KICK_MEMBERS permission + await guild_owner_check(user_id, guild_id) + await remove_member(guild_id, member_id) + return '', 204 + + +@bp.route('//bans', methods=['GET']) +async def get_bans(guild_id): + user_id = await token_check() + + # TODO: check BAN_MEMBERS permission + await guild_owner_check(user_id, guild_id) + + bans = await app.db.fetch(""" + SELECT user_id, reason + FROM bans + WHERE bans.guild_id = $1 + """, guild_id) + + res = [] + + for ban in bans: + res.append({ + 'reason': ban['reason'], + 'user': await app.storage.get_user(ban['user_id']) + }) + + return jsonify(res) + + +@bp.route('//bans/', methods=['PUT']) +async def create_ban(guild_id, member_id): + user_id = await token_check() + + # TODO: check BAN_MEMBERS permission + await guild_owner_check(user_id, guild_id) + + j = await request.get_json() + + await app.db.execute(""" + INSERT INTO bans (guild_id, user_id, reason) + VALUES ($1, $2, $3) + """, guild_id, member_id, j.get('reason', '')) + + await remove_member(guild_id, member_id) + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', { + 'guild_id': str(guild_id), + 'user': await app.storage.get_user(member_id) + }) + + return '', 204 + + +@bp.route('//bans/', methods=['DELETE']) +async def remove_ban(guild_id, banned_id): + user_id = await token_check() + + # TODO: check BAN_MEMBERS permission + await guild_owner_check(guild_id, user_id) + + res = await app.db.execute(""" + DELETE FROM bans + WHERE guild_id = $1 AND user_id = $@ + """, guild_id, banned_id) + + # we don't really need to dispatch GUILD_BAN_REMOVE + # when no bans were actually removed. + if res == 'DELETE 0': + return '', 204 + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_REMOVE', { + 'guild_id': str(guild_id), + 'user': await app.storage.get_user(banned_id) + }) + + return '', 204 + + +async def get_prune(guild_id: int, days: int) -> list: + """Get all members in a guild that: + + - did not login in ``days`` days. + - don't have any roles. + """ + # a good solution would be in pure sql. + member_ids = await app.db.fetch(f""" + SELECT id + FROM users + JOIN members + ON members.guild_id = $1 AND members.user_id = users.id + WHERE users.last_session < (now() - (interval '{days} days')) + """, guild_id) + + member_ids = [r['id'] for r in member_ids] + members = [] + + for member_id in member_ids: + role_count = await app.db.fetchval(""" + SELECT COUNT(*) + FROM member_roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + if role_count == 0: + members.append(member_id) + + return members + + +@bp.route('//prune', methods=['GET']) +async def get_guild_prune_count(guild_id): + user_id = await token_check() + + # TODO: check KICK_MEMBERS + await guild_owner_check(user_id, guild_id) + + j = validate(await request.get_json(), GUILD_PRUNE) + days = j['days'] + member_ids = await get_prune(guild_id, days) + + return jsonify({ + 'pruned': len(member_ids), + }) + + +@bp.route('//prune', methods=['POST']) +async def begin_guild_prune(guild_id): + user_id = await token_check() + + # TODO: check KICK_MEMBERS + await guild_owner_check(user_id, guild_id) + + j = validate(await request.get_json(), GUILD_PRUNE) + days = j['days'] + member_ids = await get_prune(guild_id, days) + + app.loop.create_task(remove_member_multi(guild_id, member_ids)) + + return jsonify({ + 'pruned': len(member_ids) + }) diff --git a/litecord/blueprints/guild/roles.py b/litecord/blueprints/guild/roles.py new file mode 100644 index 0000000..003b4e4 --- /dev/null +++ b/litecord/blueprints/guild/roles.py @@ -0,0 +1,315 @@ +from typing import List, Dict, Any, Union + +from quart import Blueprint, request, current_app as app, jsonify + +from litecord.auth import token_check + +from litecord.blueprints.checks import ( + guild_check, guild_owner_check +) +from litecord.schemas import ( + validate, ROLE_CREATE, ROLE_UPDATE, ROLE_UPDATE_POSITION +) + +from litecord.snowflake import get_snowflake +from litecord.utils import dict_get + +DEFAULT_EVERYONE_PERMS = 104324161 +bp = Blueprint('guild_roles', __name__) + + +@bp.route('//roles', methods=['GET']) +async def get_guild_roles(guild_id): + """Get all roles in a guild.""" + user_id = await token_check() + await guild_check(user_id, guild_id) + + return jsonify( + await app.storage.get_role_data(guild_id) + ) + + +async def create_role(guild_id, name: str, **kwargs): + """Create a role in a guild.""" + new_role_id = get_snowflake() + + # TODO: use @everyone's perm number + default_perms = dict_get(kwargs, 'default_perms', DEFAULT_EVERYONE_PERMS) + + max_pos = await app.db.fetchval(""" + SELECT MAX(position) + FROM roles + WHERE guild_id = $1 + """, guild_id) + + await app.db.execute( + """ + INSERT INTO roles (id, guild_id, name, color, + hoist, position, permissions, managed, mentionable) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + """, + new_role_id, + guild_id, + name, + dict_get(kwargs, 'color', 0), + dict_get(kwargs, 'hoist', False), + + # set position = 0 when there isn't any + # other role (when we're creating the + # @everyone role) + max_pos + 1 if max_pos is not None else 0, + int(dict_get(kwargs, 'permissions', default_perms)), + False, + dict_get(kwargs, 'mentionable', False) + ) + + role = await app.storage.get_role(new_role_id, guild_id) + await app.dispatcher.dispatch_guild( + guild_id, 'GUILD_ROLE_CREATE', { + 'guild_id': str(guild_id), + 'role': role, + }) + + return role + + +@bp.route('//roles', methods=['POST']) +async def create_guild_role(guild_id: int): + """Add a role to a guild""" + user_id = await token_check() + + # TODO: use check_guild and MANAGE_ROLES permission + await guild_owner_check(user_id, guild_id) + + # client can just send null + j = validate(await request.get_json() or {}, ROLE_CREATE) + + role_name = j['name'] + j.pop('name') + + role = await create_role(guild_id, role_name, **j) + + return jsonify(role) + + +async def _role_update_dispatch(role_id: int, guild_id: int): + """Dispatch a GUILD_ROLE_UPDATE with updated information on a role.""" + role = await app.storage.get_role(role_id, guild_id) + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_UPDATE', { + 'guild_id': str(guild_id), + 'role': role, + }) + + return role + + +async def _role_pairs_update(guild_id: int, pairs: list): + """Update the roles' positions. + + Dispatches GUILD_ROLE_UPDATE for all roles being updated. + """ + for pair in pairs: + pair_1, pair_2 = pair + + role_1, new_pos_1 = pair_1 + role_2, new_pos_2 = pair_2 + + conn = await app.db.acquire() + async with conn.transaction(): + # update happens in a transaction + # so we don't fuck it up + await conn.execute(""" + UPDATE roles + SET position = $1 + WHERE roles.id = $2 + """, new_pos_1, role_1) + + await conn.execute(""" + UPDATE roles + SET position = $1 + WHERE roles.id = $2 + """, new_pos_2, role_2) + + await app.db.release(conn) + + # the route fires multiple Guild Role Update. + await _role_update_dispatch(role_1, guild_id) + await _role_update_dispatch(role_2, guild_id) + + +def gen_pairs(list_of_changes: List[Dict[str, int]], + current_state: Dict[int, int], + blacklist: List[int] = None) -> List[tuple]: + """Generate a list of pairs that, when applied to the database, + will generate the desired state given in list_of_changes. + + We must check if the given list_of_changes isn't overwriting an + element's (such as a role or a channel) position to an existing one, + without there having an already existing change for the other one. + + Here's a pratical explanation with roles: + + R1 (in position RP1) wants to be in the same position + as R2 (currently in position RP2). + + So, if we did the simpler approach, list_of_changes + would just contain the preferred change: (R1, RP2). + + With gen_pairs, there MUST be a (R2, RP1) in list_of_changes, + if there is, the given result in gen_pairs will be a pair + ((R1, RP2), (R2, RP1)) which is then used to actually + update the roles' positions in a transaction. + + Parameters + ---------- + list_of_changes: + A list of dictionaries with ``id`` and ``position`` + fields, describing the preferred changes. + current_state: + Dictionary containing the current state of the list + of elements (roles or channels). Points position + to element ID. + blacklist: + List of IDs that shouldn't be moved. + + Returns + ------- + list + List of swaps to do to achieve the preferred + state given by ``list_of_changes``. + """ + pairs = [] + blacklist = blacklist or [] + + preferred_state = {element['id']: element['position'] + for element in list_of_changes} + + for blacklisted_id in blacklist: + preferred_state.pop(blacklisted_id) + + # for each change, we must find a matching change + # in the same list, so we can make a swap pair + for change in list_of_changes: + element_1, new_pos_1 = change['id'], change['position'] + + # check current pairs + # so we don't repeat an element + flag = False + + for pair in pairs: + if (element_1, new_pos_1) in pair: + flag = True + + # skip if found + if flag: + continue + + # search if there is a role/channel in the + # position we want to change to + element_2 = current_state.get(new_pos_1) + + # if there is, is that existing channel being + # swapped to another position? + new_pos_2 = preferred_state.get(element_2) + + # if its being swapped to leave space, add it + # to the pairs list + if new_pos_2: + pairs.append( + ((element_1, new_pos_1), (element_2, new_pos_2)) + ) + + return pairs + + +@bp.route('//roles', methods=['PATCH']) +async def update_guild_role_positions(guild_id): + """Update the positions for a bunch of roles.""" + user_id = await token_check() + + # TODO: check MANAGE_ROLES + await guild_owner_check(user_id, guild_id) + + raw_j = await request.get_json() + + # we need to do this hackiness because thats + # cerberus for ya. + j = validate({'roles': raw_j}, ROLE_UPDATE_POSITION) + + # extract the list out + j = j['roles'] + + all_roles = await app.storage.get_role_data(guild_id) + + # we'll have to calculate pairs of changing roles, + # then do the changes, etc. + roles_pos = {role['position']: int(role['id']) for role in all_roles} + + # TODO: check if the user can even change the roles in the first place, + # preferrably when we have a proper perms system. + + pairs = gen_pairs( + j, + roles_pos, + + # always ignore people trying to change + # the @everyone's role position + [guild_id] + ) + + await _role_pairs_update(guild_id, pairs) + + # return the list of all roles back + return jsonify(await app.storage.get_role_data(guild_id)) + + +@bp.route('//roles/', methods=['PATCH']) +async def update_guild_role(guild_id, role_id): + """Update a single role's information.""" + user_id = await token_check() + + # TODO: check MANAGE_ROLES + await guild_owner_check(user_id, guild_id) + + j = validate(await request.get_json(), ROLE_UPDATE) + + # we only update ints on the db, not Permissions + j['permissions'] = int(j['permissions']) + + for field in j: + await app.db.execute(f""" + UPDATE roles + SET {field} = $1 + WHERE roles.id = $2 AND roles.guild_id = $3 + """, j[field], role_id, guild_id) + + role = await _role_update_dispatch(role_id, guild_id) + return jsonify(role) + + +@bp.route('//roles/', methods=['DELETE']) +async def delete_guild_role(guild_id, role_id): + """Delete a role. + + Dispatches GUILD_ROLE_DELETE. + """ + user_id = await token_check() + + # TODO: check MANAGE_ROLES + await guild_owner_check(user_id, guild_id) + + res = await app.db.execute(""" + DELETE FROM roles + WHERE guild_id = $1 AND id = $2 + """, guild_id, role_id) + + if res == 'DELETE 0': + return '', 204 + + await app.dispatcher.dispatch_guild(guild_id, 'GUILD_ROLE_DELETE', { + 'guild_id': str(guild_id), + 'role_id': str(role_id), + }) + + return '', 204 diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index d43b033..cd40ec2 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -1,31 +1,23 @@ from quart import Blueprint, request, current_app as app, jsonify +from litecord.blueprints.guild.channels import create_guild_channel +from litecord.blueprints.guild.roles import ( + create_role, DEFAULT_EVERYONE_PERMS +) + from ..auth import token_check from ..snowflake import get_snowflake from ..enums import ChannelType -from ..errors import Forbidden, GuildNotFound, BadRequest -from ..schemas import validate, GUILD_UPDATE +from ..schemas import ( + validate, GUILD_CREATE, GUILD_UPDATE +) from .channels import channel_ack -from .checks import guild_check +from .checks import guild_check, guild_owner_check + bp = Blueprint('guilds', __name__) -async def guild_owner_check(user_id: int, guild_id: int): - """Check if a user is the owner of the guild.""" - owner_id = await app.db.fetchval(""" - SELECT owner_id - FROM guilds - WHERE guild_id = $1 - """, guild_id) - - if not owner_id: - raise GuildNotFound() - - if user_id != owner_id: - raise Forbidden('You are not the owner of the guild') - - async def create_guild_settings(guild_id: int, user_id: int): """Create guild settings for the user joining the guild.""" @@ -48,10 +40,59 @@ async def create_guild_settings(guild_id: int, user_id: int): """, m_notifs, user_id, guild_id) +async def add_member(guild_id: int, user_id: int): + """Add a user to a guild.""" + await app.db.execute(""" + INSERT INTO members (user_id, guild_id) + VALUES ($1, $2) + """, user_id, guild_id) + + await create_guild_settings(guild_id, user_id) + + +async def guild_create_roles_prep(guild_id: int, roles: list): + """Create roles in preparation in guild create.""" + # by reaching this point in the code that means + # roles is not nullable, which means + # roles has at least one element, so we can access safely. + + # the first member in the roles array + # are patches to the @everyone role + everyone_patches = roles[0] + for field in everyone_patches: + await app.db.execute(f""" + UPDATE roles + SET {field}={everyone_patches[field]} + WHERE roles.id = $1 + """, guild_id) + + default_perms = (everyone_patches.get('permissions') + or DEFAULT_EVERYONE_PERMS) + + # from the 2nd and forward, + # should be treated as new roles + for role in roles[1:]: + await create_role( + guild_id, role['name'], default_perms=default_perms, **role + ) + + +async def guild_create_channels_prep(guild_id: int, channels: list): + """Create channels pre-guild create""" + for channel_raw in channels: + channel_id = get_snowflake() + ctype = ChannelType(channel_raw['type']) + + await create_guild_channel(guild_id, channel_id, ctype) + + @bp.route('', methods=['POST']) async def create_guild(): + """Create a new guild, assigning + the user creating it as the owner and + making them join.""" user_id = await token_check() - j = await request.get_json() + j = validate(await request.get_json(), GUILD_CREATE) guild_id = get_snowflake() @@ -66,36 +107,37 @@ async def create_guild(): j.get('default_message_notifications', 0), j.get('explicit_content_filter', 0)) - await app.db.execute(""" - INSERT INTO members (user_id, guild_id) - VALUES ($1, $2) - """, user_id, guild_id) + await add_member(guild_id, user_id) - await create_guild_settings(guild_id, user_id) + # create the default @everyone role (everyone has it by default, + # so we don't insert that in the table) + # we also don't use create_role because the id of the role + # is the same as the id of the guild, and create_role + # generates a new snowflake. await app.db.execute(""" INSERT INTO roles (id, guild_id, name, position, permissions) VALUES ($1, $2, $3, $4, $5) - """, guild_id, guild_id, '@everyone', 0, 104324161) + """, guild_id, guild_id, '@everyone', 0, DEFAULT_EVERYONE_PERMS) + # add the @everyone role to the guild creator + await app.db.execute(""" + INSERT INTO member_roles (user_id, guild_id, role_id) + VALUES ($1, $2, $3) + """, user_id, guild_id, guild_id) + + # create a single #general channel. general_id = get_snowflake() - await app.db.execute(""" - INSERT INTO channels (id, channel_type) - VALUES ($1, $2) - """, general_id, ChannelType.GUILD_TEXT.value) + await create_guild_channel( + guild_id, general_id, ChannelType.GUILD_TEXT, + name='general') - await app.db.execute(""" - INSERT INTO guild_channels (id, guild_id, name, position) - VALUES ($1, $2, $3, $4) - """, general_id, guild_id, 'general', 0) + if j.get('roles'): + await guild_create_roles_prep(guild_id, j['roles']) - await app.db.execute(""" - INSERT INTO guild_text_channels (id) - VALUES ($1) - """, general_id) - - # TODO: j['roles'] and j['channels'] + if j.get('channels'): + await guild_create_channels_prep(guild_id, j['channels']) guild_total = await app.storage.get_guild_full(guild_id, user_id, 250) @@ -106,21 +148,22 @@ async def create_guild(): @bp.route('/', methods=['GET']) async def get_guild(guild_id): + """Get a single guilds' information.""" user_id = await token_check() + await guild_check(user_id, guild_id) - gj = await app.storage.get_guild(guild_id, user_id) - gj_extra = await app.storage.get_guild_extra(guild_id, user_id, 250) - - return jsonify({**gj, **gj_extra}) + return jsonify( + await app.storage.get_guild_full(guild_id, user_id, 250) + ) @bp.route('/', methods=['UPDATE']) async def update_guild(guild_id): user_id = await token_check() - await guild_check(user_id, guild_id) - j = validate(await request.get_json(), GUILD_UPDATE) # TODO: check MANAGE_GUILD + await guild_check(user_id, guild_id) + j = validate(await request.get_json(), GUILD_UPDATE) if 'owner_id' in j: await guild_owner_check(user_id, guild_id) @@ -139,8 +182,6 @@ async def update_guild(guild_id): """, j['name'], guild_id) if 'region' in j: - # TODO: check region value - await app.db.execute(""" UPDATE guilds SET region = $1 @@ -167,15 +208,14 @@ async def update_guild(guild_id): WHERE guild_id = $2 """, j[field], guild_id) - # return guild object - gj = await app.storage.get_guild(guild_id, user_id) - gj_extra = await app.storage.get_guild_extra(guild_id, user_id, 250) + guild = await app.storage.get_guild_full( + guild_id, user_id + ) - gj_total = {**gj, **gj_extra} + await app.dispatcher.dispatch_guild( + guild_id, 'GUILD_UPDATE', guild) - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_UPDATE', gj_total) - - return jsonify({**gj, **gj_extra}) + return jsonify(guild) @bp.route('/', methods=['DELETE']) @@ -185,7 +225,7 @@ async def delete_guild(guild_id): await guild_owner_check(user_id, guild_id) await app.db.execute(""" - DELETE FROM guild + DELETE FROM guilds WHERE guilds.id = $1 """, guild_id) @@ -202,264 +242,12 @@ async def delete_guild(guild_id): return '', 204 -@bp.route('//channels', methods=['GET']) -async def get_guild_channels(guild_id): - user_id = await token_check() - await guild_check(user_id, guild_id) - - channels = await app.storage.get_channel_data(guild_id) - return jsonify(channels) - - -@bp.route('//channels', methods=['POST']) -async def create_channel(guild_id): - user_id = await token_check() - j = await request.get_json() - - # TODO: check permissions for MANAGE_CHANNELS - await guild_check(user_id, guild_id) - - new_channel_id = get_snowflake() - channel_type = j.get('type', ChannelType.GUILD_TEXT) - - channel_type = ChannelType(channel_type) - - if channel_type not in (ChannelType.GUILD_TEXT, - ChannelType.GUILD_VOICE): - raise BadRequest('Invalid channel type') - - await app.db.execute(""" - INSERT INTO channels (id, channel_type) - VALUES ($1, $2) - """, new_channel_id, channel_type.value) - - max_pos = await app.db.fetchval(""" - SELECT MAX(position) - FROM guild_channels - WHERE guild_id = $1 - """, guild_id) - - if channel_type == ChannelType.GUILD_TEXT: - await app.db.execute(""" - INSERT INTO guild_channels (id, guild_id, name, position) - VALUES ($1, $2, $3, $4) - """, new_channel_id, guild_id, j['name'], max_pos + 1) - - await app.db.execute(""" - INSERT INTO guild_text_channels (id) - VALUES ($1) - """, new_channel_id) - - elif channel_type == ChannelType.GUILD_VOICE: - raise NotImplementedError() - - chan = await app.storage.get_channel(new_channel_id) - await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_CREATE', chan) - return jsonify(chan) - - -@bp.route('//channels', methods=['PATCH']) -async def modify_channel_pos(guild_id): - user_id = await token_check() - await guild_check(user_id, guild_id) - await request.get_json() - - # TODO: this route - - raise NotImplementedError - - -@bp.route('//members/', methods=['GET']) -async def get_guild_member(guild_id, member_id): - user_id = await token_check() - await guild_check(user_id, guild_id) - - member = await app.storage.get_single_member(guild_id, member_id) - return jsonify(member) - - -@bp.route('//members', methods=['GET']) -async def get_members(guild_id): - user_id = await token_check() - await guild_check(user_id, guild_id) - - j = await request.get_json() - - limit, after = int(j.get('limit', 1)), j.get('after', 0) - - if limit < 1 or limit > 1000: - raise BadRequest('limit not in 1-1000 range') - - user_ids = await app.db.fetch(f""" - SELECT user_id - WHERE guild_id = $1, user_id > $2 - LIMIT {limit} - ORDER BY user_id ASC - """, guild_id, after) - - user_ids = [r[0] for r in user_ids] - members = await app.storage.get_member_multi(guild_id, user_ids) - return jsonify(members) - - -@bp.route('//members/', methods=['PATCH']) -async def modify_guild_member(guild_id, member_id): - j = await request.get_json() - - if 'nick' in j: - # TODO: check MANAGE_NICKNAMES - - await app.db.execute(""" - UPDATE members - SET nickname = $1 - WHERE user_id = $2 AND guild_id = $3 - """, j['nick'], member_id, guild_id) - - if 'mute' in j: - # TODO: check MUTE_MEMBERS - - await app.db.execute(""" - UPDATE members - SET muted = $1 - WHERE user_id = $2 AND guild_id = $3 - """, j['mute'], member_id, guild_id) - - if 'deaf' in j: - # TODO: check DEAFEN_MEMBERS - - await app.db.execute(""" - UPDATE members - SET deafened = $1 - WHERE user_id = $2 AND guild_id = $3 - """, j['deaf'], member_id, guild_id) - - if 'channel_id' in j: - # TODO: check MOVE_MEMBERS - # TODO: change the member's voice channel - pass - - member = await app.storage.get_member_data_one(guild_id, member_id) - member.pop('joined_at') - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ - 'guild_id': str(guild_id) - }, **member}) - - return '', 204 - - -@bp.route('//members/@me/nick', methods=['PATCH']) -async def update_nickname(guild_id): - user_id = await token_check() - await guild_check(user_id, guild_id) - - j = await request.get_json() - - await app.db.execute(""" - UPDATE members - SET nickname = $1 - WHERE user_id = $2 AND guild_id = $3 - """, j['nick'], user_id, guild_id) - - member = await app.storage.get_member_data_one(guild_id, user_id) - member.pop('joined_at') - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_UPDATE', {**{ - 'guild_id': str(guild_id) - }, **member}) - - return j['nick'] - - -@bp.route('//members/', methods=['DELETE']) -async def kick_member(guild_id, member_id): - user_id = await token_check() - - # TODO: check KICK_MEMBERS permission - await guild_owner_check(user_id, guild_id) - - await app.db.execute(""" - DELETE FROM members - WHERE guild_id = $1 AND user_id = $2 - """, guild_id, member_id) - - await app.dispatcher.dispatch_user(user_id, 'GUILD_DELETE', { - 'guild_id': guild_id, - 'unavailable': False, - }) - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', { - 'guild': guild_id, - 'user': await app.storage.get_user(member_id), - }) - - return '', 204 - - -@bp.route('//bans', methods=['GET']) -async def get_bans(guild_id): - user_id = await token_check() - - # TODO: check BAN_MEMBERS permission - await guild_owner_check(user_id, guild_id) - - bans = await app.db.fetch(""" - SELECT user_id, reason - FROM bans - WHERE bans.guild_id = $1 - """, guild_id) - - res = [] - - for ban in bans: - res.append({ - 'reason': ban['reason'], - 'user': await app.storage.get_user(ban['user_id']) - }) - - return jsonify(res) - - -@bp.route('//bans/', methods=['PUT']) -async def create_ban(guild_id, member_id): - user_id = await token_check() - - # TODO: check BAN_MEMBERS permission - await guild_owner_check(user_id, guild_id) - - j = await request.get_json() - - await app.db.execute(""" - INSERT INTO bans (guild_id, user_id, reason) - VALUES ($1, $2, $3) - """, guild_id, member_id, j.get('reason', '')) - - await app.db.execute(""" - DELETE FROM members - WHERE guild_id = $1 AND user_id = $2 - """, guild_id, user_id) - - await app.dispatcher.dispatch_user(member_id, 'GUILD_DELETE', { - 'guild_id': guild_id, - 'unavailable': False, - }) - - await app.dispatcher.unsub('guild', guild_id, member_id) - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_MEMBER_REMOVE', { - 'guild': guild_id, - 'user': await app.storage.get_user(member_id), - }) - - await app.dispatcher.dispatch_guild(guild_id, 'GUILD_BAN_ADD', {**{ - 'guild': guild_id, - }, **(await app.storage.get_user(member_id))}) - - return '', 204 - - @bp.route('//messages/search') async def search_messages(guild_id): + """Search messages in a guild. + + This is an undocumented route. + """ user_id = await token_check() await guild_check(user_id, guild_id) @@ -474,6 +262,7 @@ async def search_messages(guild_id): @bp.route('//ack', methods=['POST']) async def ack_guild(guild_id): + """ACKnowledge all messages in the guild.""" user_id = await token_check() await guild_check(user_id, guild_id) diff --git a/litecord/blueprints/invites.py b/litecord/blueprints/invites.py index 246b8d0..2172ea7 100644 --- a/litecord/blueprints/invites.py +++ b/litecord/blueprints/invites.py @@ -185,7 +185,7 @@ async def use_invite(invite_code): }) # subscribe new member to guild, so they get events n stuff - app.dispatcher.sub_guild(guild_id, user_id) + await app.dispatcher.sub('guild', guild_id, user_id) # tell the new member that theres the guild it just joined. # we use dispatch_user_guild so that we send the GUILD_CREATE diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 8ab4706..cd47013 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -219,9 +219,11 @@ async def get_me_guilds(): partial = await app.db.fetchrow(""" SELECT id::text, name, icon, owner_id FROM guilds - WHERE guild_id = $1 + WHERE guilds.id = $1 """, guild_id) + partial = dict(partial) + # TODO: partial['permissions'] partial['owner'] = partial['owner_id'] == user_id partial.pop('owner_id') @@ -279,10 +281,11 @@ async def put_note(target_id: int): INSERT INTO notes (user_id, target_id, note) VALUES ($1, $2, $3) - ON CONFLICT DO UPDATE SET + ON CONFLICT ON CONSTRAINT notes_pkey + DO UPDATE SET note = $3 - WHERE - user_id = $1 AND target_id = $2 + WHERE notes.user_id = $1 + AND notes.target_id = $2 """, user_id, target_id, note) await app.dispatcher.dispatch_user(user_id, 'USER_NOTE_UPDATE', { @@ -315,7 +318,8 @@ async def patch_current_settings(): await app.db.execute(f""" UPDATE user_settings SET {key}=$1 - """, j[key]) + WHERE id = $2 + """, j[key], user_id) settings = await app.storage.get_user_settings(user_id) await app.dispatcher.dispatch_user( @@ -444,20 +448,20 @@ async def patch_guild_settings(guild_id: int): continue for field in chan_overrides: - res = await app.db.execute(f""" - UPDATE guild_settings_channel_overrides - SET {field} = $1 - WHERE user_id = $2 - AND guild_id = $3 - AND channel_id = $4 - """, chan_overrides[field], user_id, guild_id, chan_id) - - if res == 'UPDATE 0': - await app.db.execute(f""" - INSERT INTO guild_settings_channel_overrides - (user_id, guild_id, channel_id, {field}) - VALUES ($1, $2, $3, $4) - """, user_id, guild_id, chan_id, chan_overrides[field]) + await app.db.execute(f""" + INSERT INTO guild_settings_channel_overrides + (user_id, guild_id, channel_id, {field}) + VALUES + ($1, $2, $3, $4) + ON CONFLICT + ON CONSTRAINT guild_settings_channel_overrides_pkey + DO + UPDATE + SET {field} = $4 + WHERE guild_settings_channel_overrides.user_id = $1 + AND guild_settings_channel_overrides.guild_id = $2 + AND guild_settings_channel_overrides.channel_id = $3 + """, user_id, guild_id, chan_id, chan_overrides[field]) settings = await app.storage.get_guild_settings_one(user_id, guild_id) diff --git a/litecord/dispatcher.py b/litecord/dispatcher.py index 009ac5e..10d2f7a 100644 --- a/litecord/dispatcher.py +++ b/litecord/dispatcher.py @@ -4,7 +4,8 @@ from typing import List, Any from logbook import Logger from .pubsub import GuildDispatcher, MemberDispatcher, \ - UserDispatcher, ChannelDispatcher, FriendDispatcher + UserDispatcher, ChannelDispatcher, FriendDispatcher, \ + LazyGuildDispatcher log = Logger(__name__) @@ -35,6 +36,7 @@ class EventDispatcher: 'channel': ChannelDispatcher(self), 'user': UserDispatcher(self), 'friend': FriendDispatcher(self), + 'lazy_guild': LazyGuildDispatcher(self), } async def action(self, backend_str: str, action: str, key, identifier): @@ -104,6 +106,15 @@ class EventDispatcher: for key in keys: await self.dispatch(backend_str, key, *args, **kwargs) + async def dispatch_filter(self, backend_str: str, + key: Any, func, *args): + """Dispatch to a backend that only accepts + (event, data) arguments with an optional filter + function.""" + backend = self.backends[backend_str] + key = backend.KEY_TYPE(key) + return await backend.dispatch_filter(key, func, *args) + async def reset(self, backend_str: str, key: Any): """Reset the bucket in the given backend.""" backend = self.backends[backend_str] diff --git a/litecord/errors.py b/litecord/errors.py index fe4f130..70afcf9 100644 --- a/litecord/errors.py +++ b/litecord/errors.py @@ -29,16 +29,24 @@ class NotFound(LitecordError): status_code = 404 -class GuildNotFound(LitecordError): - status_code = 404 +class GuildNotFound(NotFound): + error_code = 10004 -class ChannelNotFound(LitecordError): - status_code = 404 +class ChannelNotFound(NotFound): + error_code = 10003 -class MessageNotFound(LitecordError): - status_code = 404 +class MessageNotFound(NotFound): + error_code = 10008 + + +class Ratelimited(LitecordError): + status_code = 429 + + +class MissingPermissions(Forbidden): + error_code = 50013 class WebsocketClose(Exception): diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index 56d00bc..5185b1d 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -1,18 +1,68 @@ +import asyncio + from typing import List, Dict, Any from collections import defaultdict +from websockets.exceptions import ConnectionClosed from logbook import Logger -from .state import GatewayState +from litecord.gateway.state import GatewayState +from litecord.gateway.opcodes import OP log = Logger(__name__) +class ManagerClose(Exception): + pass + + +class StateDictWrapper: + """Wrap a mapping so that any kind of access to the mapping while the + state manager is closed raises a ManagerClose error""" + def __init__(self, state_manager, mapping): + self.state_manager = state_manager + self._map = mapping + + def _check_closed(self): + if self.state_manager.closed: + raise ManagerClose() + + def __getitem__(self, key): + self._check_closed() + return self._map[key] + + def __delitem__(self, key): + self._check_closed() + del self._map[key] + + def __setitem__(self, key, value): + if not self.state_manager.accept_new: + raise ManagerClose() + + self._check_closed() + self._map[key] = value + + def __iter__(self): + return self._map.__iter__() + + def pop(self, key): + return self._map.pop(key) + + def values(self): + return self._map.values() + + class StateManager: """Manager for gateway state information.""" def __init__(self): + #: closed flag + self.closed = False + + #: accept new states? + self.accept_new = True + # { # user_id: { # session_id: GatewayState, @@ -20,7 +70,10 @@ class StateManager: # }, # user_id_2: {}, ... # } - self.states = defaultdict(dict) + self.states = StateDictWrapper(self, defaultdict(dict)) + + #: raw mapping from session ids to GatewayState + self.states_raw = StateDictWrapper(self, {}) def insert(self, state: GatewayState): """Insert a new state object.""" @@ -28,6 +81,7 @@ class StateManager: log.debug('inserting state: {!r}', state) user_states[state.session_id] = state + self.states_raw[state.session_id] = state def fetch(self, user_id: int, session_id: str) -> GatewayState: """Fetch a state object from the manager. @@ -40,11 +94,20 @@ class StateManager: """ return self.states[user_id][session_id] + def fetch_raw(self, session_id: str) -> GatewayState: + """Fetch a single state given the Session ID.""" + return self.states_raw[session_id] + def remove(self, state): """Remove a state from the registry""" if not state: return + try: + self.states_raw.pop(state.session_id) + except KeyError: + pass + try: log.debug('removing state: {!r}', state) self.states[state.user_id].pop(state.session_id) @@ -100,3 +163,54 @@ class StateManager: states.extend(member_states) return states + + async def shutdown_single(self, state: GatewayState): + """Send OP Reconnect to a single connection.""" + websocket = state.ws + + await websocket.send({ + 'op': OP.RECONNECT + }) + + # wait 200ms + # so that the client has time to process + # our payload then close the connection + await asyncio.sleep(0.2) + + try: + # try to close the connection ourselves + await websocket.ws.close( + code=4000, + reason='litecord shutting down' + ) + except ConnectionClosed: + log.info('client {} already closed', state) + + def gen_close_tasks(self): + """Generate the tasks that will order the clients + to reconnect. + + This is required to be ran before :meth:`StateManager.close`, + since this function doesn't wait for the tasks to complete. + """ + + self.accept_new = False + + #: store the shutdown tasks + tasks = [] + + for state in self.states_raw.values(): + if not state.ws: + continue + + tasks.append( + self.shutdown_single(state) + ) + + log.info('made {} shutdown tasks', len(tasks)) + + return tasks + + def close(self): + """Close the state manager.""" + self.closed = True diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 9c6bd4d..8001a28 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -28,7 +28,8 @@ WebsocketProperties = collections.namedtuple( ) WebsocketObjects = collections.namedtuple( - 'WebsocketObjects', 'db state_manager storage loop dispatcher presence' + 'WebsocketObjects', ('db', 'state_manager', 'storage', + 'loop', 'dispatcher', 'presence', 'ratelimiter') ) @@ -44,8 +45,38 @@ def encode_etf(payload) -> str: return earl.pack(payload) +def _etf_decode_dict(data): + # NOTE: this is a very slow implementation to + # decode the dictionary. + + if isinstance(data, bytes): + return data.decode() + + if not isinstance(data, dict): + return data + + _copy = dict(data) + result = {} + + for key in _copy.keys(): + # assuming key is bytes rn. + new_k = key.decode() + + # maybe nested dicts, so... + result[new_k] = _etf_decode_dict(data[key]) + + return result + def decode_etf(data: bytes): - return earl.unpack(data) + res = earl.unpack(data) + + if isinstance(res, bytes): + return data.decode() + + if isinstance(res, dict): + return _etf_decode_dict(res) + + return res class GatewayWebsocket: @@ -108,6 +139,11 @@ class GatewayWebsocket: else: await self.ws.send(encoded.decode()) + def _check_ratelimit(self, key: str, ratelimit_key: str): + ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}') + bucket = ratelimit.get_bucket(ratelimit_key) + return bucket.update_rate_limit() + async def _hb_wait(self, interval: int): """Wait heartbeat""" # if the client heartbeats in time, @@ -312,6 +348,14 @@ class GatewayWebsocket: async def update_status(self, status: dict): """Update the status of the current websocket connection.""" + if not self.state: + return + + if self._check_ratelimit('presence', self.state.session_id): + # Presence Updates beyond the ratelimit + # are just silently dropped. + return + if status is None: status = { 'afk': False, @@ -365,6 +409,15 @@ class GatewayWebsocket: 'op': OP.HEARTBEAT_ACK, }) + async def _connect_ratelimit(self, user_id: int): + if self._check_ratelimit('connect', user_id): + await self.invalidate_session(False) + raise WebsocketClose(4009, 'You are being ratelimited.') + + if self._check_ratelimit('session', user_id): + await self.invalidate_session(False) + raise WebsocketClose(4004, 'Websocket Session Ratelimit reached.') + async def handle_2(self, payload: Dict[str, Any]): """Handle the OP 2 Identify packet.""" try: @@ -384,6 +437,8 @@ class GatewayWebsocket: except (Unauthorized, Forbidden): raise WebsocketClose(4004, 'Authentication failed') + await self._connect_ratelimit(user_id) + bot = await self.ext.db.fetchval(""" SELECT bot FROM users WHERE id = $1 @@ -641,9 +696,11 @@ class GatewayWebsocket: This is the known structure of GUILD_MEMBER_LIST_UPDATE: + group_id = 'online' | 'offline' | role_id (string) + sync_item = { 'group': { - 'id': string, // 'online' | 'offline' | any role id + 'id': group_id, 'count': num } } | { @@ -653,7 +710,7 @@ class GatewayWebsocket: list_op = 'SYNC' | 'INVALIDATE' | 'INSERT' | 'UPDATE' | 'DELETE' list_data = { - 'id': "everyone" // ?? + 'id': channel_id | 'everyone', 'guild_id': guild_id, 'ops': [ @@ -666,10 +723,10 @@ class GatewayWebsocket: // exists if op = 'SYNC' 'items': sync_item[], - // exists if op = 'INSERT' or 'DELETE' + // exists if op == 'INSERT' | 'DELETE' | 'UPDATE' 'index': num, - // exists if op = 'INSERT' + // exists if op == 'INSERT' | 'UPDATE' 'item': sync_item, } ], @@ -678,31 +735,11 @@ class GatewayWebsocket: // separately from the online list? 'groups': [ { - 'id': string // 'online' | 'offline' | any role id + 'id': group_id 'count': num }, ... ] } - - # Implementation defails. - - Lazy guilds are complicated to deal with in the backend level - as there are a lot of computation to be done for each request. - - The current implementation is rudimentary and does not account - for any roles inside the guild. - - A correct implementation would take account of roles and make - the correct groups on list_data: - - For each channel in lazy_request['channels']: - - get all roles that have Read Messages on the channel: - - Also fetch their member counts, as it'll be important - - with the role list, order them like you normally would - (by their role priority) - - based on the channel's range's min and max and the ordered - role list, you can get the roles wanted for your list_data reply. - - make new groups ONLY when the role is hoisted. """ data = payload['d'] @@ -713,65 +750,16 @@ class GatewayWebsocket: if guild_id not in gids: return - member_ids = await self.storage.get_member_ids(guild_id) - log.debug('lazy: loading {} members', len(member_ids)) + # make shard query + lazy_guilds = self.ext.dispatcher.backends['lazy_guild'] - # the current implementation is rudimentary and only - # generates two groups: online and offline, using - # PresenceManager.guild_presences to fill list_data. + for chan_id, ranges in data.get('channels', {}).items(): + chan_id = int(chan_id) + member_list = await lazy_guilds.get_gml(chan_id) - # this also doesn't take account the channels in lazy_request. - - guild_presences = await self.presence.guild_presences(member_ids, - guild_id) - - online = [{'member': p} - for p in guild_presences - if p['status'] == 'online'] - offline = [{'member': p} - for p in guild_presences - if p['status'] == 'offline'] - - log.debug('lazy: {} presences, online={}, offline={}', - len(guild_presences), - len(online), - len(offline)) - - # construct items in the WORST WAY POSSIBLE. - items = [{ - 'group': { - 'id': 'online', - 'count': len(online), - } - }] + online + [{ - 'group': { - 'id': 'offline', - 'count': len(offline), - } - }] + offline - - await self.dispatch('GUILD_MEMBER_LIST_UPDATE', { - 'id': 'everyone', - 'guild_id': data['guild_id'], - 'groups': [ - { - 'id': 'online', - 'count': len(online), - }, - { - 'id': 'offline', - 'count': len(offline), - } - ], - - 'ops': [ - { - 'range': [0, 99], - 'op': 'SYNC', - 'items': items - } - ] - }) + await member_list.shard_query( + self.state.session_id, ranges + ) async def process_message(self, payload): """Process a single message coming in from the client.""" @@ -788,17 +776,36 @@ class GatewayWebsocket: await handler(payload) + async def _msg_ratelimit(self): + if self._check_ratelimit('messages', self.state.session_id): + raise WebsocketClose(4008, 'You are being ratelimited.') + async def listen_messages(self): """Listen for messages coming in from the websocket.""" + + # close anyone trying to login while the + # server is shutting down + if self.ext.state_manager.closed: + raise WebsocketClose(4000, 'state manager closed') + + if not self.ext.state_manager.accept_new: + raise WebsocketClose(4000, 'state manager closed for new') + while True: message = await self.ws.recv() if len(message) > 4096: raise DecodeError('Payload length exceeded') + if self.state: + await self._msg_ratelimit() + payload = self.decoder(message) await self.process_message(payload) def _cleanup(self): + for task in self.wsp.tasks.values(): + task.cancel() + if self.state: self.ext.state_manager.remove(self.state) self.state.ws = None diff --git a/litecord/permissions.py b/litecord/permissions.py new file mode 100644 index 0000000..b63fa98 --- /dev/null +++ b/litecord/permissions.py @@ -0,0 +1,242 @@ +import ctypes + +from quart import current_app as app, request + +# so we don't keep repeating the same +# type for all the fields +_i = ctypes.c_uint8 + +class _RawPermsBits(ctypes.LittleEndianStructure): + """raw bitfield for discord's permission number.""" + _fields_ = [ + ('create_invites', _i, 1), + ('kick_members', _i, 1), + ('ban_members', _i, 1), + ('administrator', _i, 1), + ('manage_channels', _i, 1), + ('manage_guild', _i, 1), + ('add_reactions', _i, 1), + ('view_audit_log', _i, 1), + ('priority_speaker', _i, 1), + ('_unused1', _i, 1), + ('read_messages', _i, 1), + ('send_messages', _i, 1), + ('send_tts', _i, 1), + ('manage_messages', _i, 1), + ('embed_links', _i, 1), + ('attach_files', _i, 1), + ('read_history', _i, 1), + ('mention_everyone', _i, 1), + ('external_emojis', _i, 1), + ('_unused2', _i, 1), + ('connect', _i, 1), + ('speak', _i, 1), + ('mute_members', _i, 1), + ('deafen_members', _i, 1), + ('move_members', _i, 1), + ('use_voice_activation', _i, 1), + ('change_nickname', _i, 1), + ('manage_nicknames', _i, 1), + ('manage_roles', _i, 1), + ('manage_webhooks', _i, 1), + ('manage_emojis', _i, 1), + ] + + +class Permissions(ctypes.Union): + _fields_ = [ + ('bits', _RawPermsBits), + ('binary', ctypes.c_uint64), + ] + + def __init__(self, val: int): + self.binary = val + + def __repr__(self): + return f'' + + def __int__(self): + return self.binary + + def numby(self): + return self.binary + + +ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111) + + +async def get_role_perms(guild_id, role_id, storage=None) -> Permissions: + """Get the raw :class:`Permissions` object for a role.""" + if not storage: + storage = app.storage + + perms = await storage.db.fetchval(""" + SELECT permissions + FROM roles + WHERE guild_id = $1 AND id = $2 + """, guild_id, role_id) + + return Permissions(perms) + + +async def base_permissions(member_id, guild_id, storage=None) -> Permissions: + """Compute the base permissions for a given user. + + Base permissions are + (permissions from @everyone role) + + (permissions from any other role the member has) + + This will give ALL_PERMISSIONS if base permissions + has the Administrator bit set. + """ + + if not storage: + storage = app.storage + + owner_id = await storage.db.fetchval(""" + SELECT owner_id + FROM guilds + WHERE id = $1 + """, guild_id) + + if owner_id == member_id: + return ALL_PERMISSIONS + + # get permissions for @everyone + permissions = await get_role_perms(guild_id, guild_id, storage) + + role_ids = await storage.db.fetch(""" + SELECT role_id + FROM member_roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + role_perms = [] + + for row in role_ids: + rperm = await storage.db.fetchval(""" + SELECT permissions + FROM roles + WHERE id = $1 + """, row['role_id']) + + role_perms.append(rperm) + + for perm_num in role_perms: + permissions.binary |= perm_num + + if permissions.bits.administrator: + return ALL_PERMISSIONS + + return permissions + + +def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions: + # we make a copy of the binary representation + # so we don't modify the old perms in-place + # which could be an unwanted side-effect + result = perms.binary + + # negate the permissions that are denied + result &= ~overwrite['deny'] + + # combine the permissions that are allowed + result |= overwrite['allow'] + + return Permissions(result) + + +def overwrite_find_mix(perms: Permissions, overwrites: dict, + target_id: int) -> Permissions: + overwrite = overwrites.get(target_id) + + if overwrite: + # only mix if overwrite found + return overwrite_mix(perms, overwrite) + + return perms + + +async def role_permissions(guild_id: int, role_id: int, + channel_id: int, storage=None) -> Permissions: + """Get the permissions for a role, in relation to a channel""" + if not storage: + storage = app.storage + + perms = await get_role_perms(guild_id, role_id, storage) + + overwrite = await storage.db.fetchrow(""" + SELECT allow, deny + FROM channel_overwrites + WHERE channel_id = $1 AND target_type = $2 AND target_role = $3 + """, channel_id, 1, role_id) + + if overwrite: + perms = overwrite_mix(perms, overwrite) + + return perms + + +async def compute_overwrites(base_perms, user_id, channel_id: int, + guild_id: int = None, storage=None): + """Compute the permissions in the context of a channel.""" + if not storage: + storage = app.storage + + if base_perms.bits.administrator: + return ALL_PERMISSIONS + + perms = base_perms + + # list of overwrites + overwrites = await storage.chan_overwrites(channel_id) + + if not guild_id: + guild_id = await storage.guild_from_channel(channel_id) + + # make it a map for better usage + overwrites = {o['id']: o for o in overwrites} + + perms = overwrite_find_mix(perms, overwrites, guild_id) + + # apply role specific overwrites + allow, deny = 0, 0 + + # fetch roles from user and convert to int + role_ids = await storage.get_member_role_ids(guild_id, user_id) + role_ids = map(int, role_ids) + + # make the allow and deny binaries + for role_id in role_ids: + overwrite = overwrites.get(role_id) + if overwrite: + allow |= overwrite['allow'] + deny |= overwrite['deny'] + + # final step for roles: mix + perms = overwrite_mix(perms, { + 'allow': allow, + 'deny': deny + }) + + # apply member specific overwrites + perms = overwrite_find_mix(perms, overwrites, user_id) + + return perms + + +async def get_permissions(member_id, channel_id, *, storage=None): + """Get all the permissions for a user in a channel.""" + if not storage: + storage = app.storage + + guild_id = await storage.guild_from_channel(channel_id) + + # for non guild channels + if not guild_id: + return ALL_PERMISSIONS + + base_perms = await base_permissions(member_id, guild_id, storage) + + return await compute_overwrites(base_perms, member_id, + channel_id, guild_id, storage) diff --git a/litecord/presence.py b/litecord/presence.py index 3666f3e..67f8edd 100644 --- a/litecord/presence.py +++ b/litecord/presence.py @@ -1,8 +1,11 @@ from typing import List, Dict, Any from random import choice +from logbook import Logger from quart import current_app as app +log = Logger(__name__) + def status_cmp(status: str, other_status: str) -> bool: """Compare if `status` is better than the `other_status` @@ -100,20 +103,64 @@ class PresenceManager: game = state['game'] - await self.dispatcher.dispatch_guild( - guild_id, 'PRESENCE_UPDATE', { - 'user': member['user'], - 'roles': member['roles'], - 'guild_id': guild_id, + lazy_guild_store = self.dispatcher.backends['lazy_guild'] + lists = lazy_guild_store.get_gml_guild(guild_id) - 'status': state['status'], + # shards that are in lazy guilds with 'everyone' + # enabled + in_lazy = [] - # rich presence stuff - 'game': game, - 'activities': [game] if game else [] - } + for member_list in lists: + session_ids = await member_list.pres_update( + int(member['user']['id']), + { + 'roles': member['roles'], + 'status': state['status'], + 'game': game + } + ) + + log.debug('Lazy Dispatch to {}', + len(session_ids)) + + if member_list.channel_id == 'everyone': + in_lazy.extend(session_ids) + + pres_update_payload = { + 'user': member['user'], + 'roles': member['roles'], + 'guild_id': str(guild_id), + + 'status': state['status'], + + # rich presence stuff + 'game': game, + 'activities': [game] if game else [] + } + + def _sane_session(session_id): + state = self.state_manager.fetch_raw(session_id) + uid = int(member['user']['id']) + + if not state: + return False + + # we don't want to send a presence update + # to the same user + return (state.user_id != uid and + session_id not in in_lazy) + + # everyone not in lazy guild mode + # gets a PRESENCE_UPDATE + await self.dispatcher.dispatch_filter( + 'guild', guild_id, + _sane_session, + + 'PRESENCE_UPDATE', pres_update_payload ) + return in_lazy + async def dispatch_pres(self, user_id: int, state: dict): """Dispatch a new presence to all guilds the user is in. @@ -122,10 +169,12 @@ class PresenceManager: if state['status'] == 'invisible': state['status'] = 'offline' + # TODO: shard-aware guild_ids = await self.storage.get_user_guilds(user_id) for guild_id in guild_ids: - await self.dispatch_guild_pres(guild_id, user_id, state) + await self.dispatch_guild_pres( + guild_id, user_id, state) # dispatch to all friends that are subscribed to them user = await self.storage.get_user(user_id) diff --git a/litecord/pubsub/__init__.py b/litecord/pubsub/__init__.py index 7320867..31388de 100644 --- a/litecord/pubsub/__init__.py +++ b/litecord/pubsub/__init__.py @@ -3,7 +3,8 @@ from .member import MemberDispatcher from .user import UserDispatcher from .channel import ChannelDispatcher from .friend import FriendDispatcher +from .lazy_guild import LazyGuildDispatcher __all__ = ['GuildDispatcher', 'MemberDispatcher', 'UserDispatcher', 'ChannelDispatcher', - 'FriendDispatcher'] + 'FriendDispatcher', 'LazyGuildDispatcher'] diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py index 3c3cb2c..621eeff 100644 --- a/litecord/pubsub/channel.py +++ b/litecord/pubsub/channel.py @@ -1,5 +1,4 @@ from typing import Any -from collections import defaultdict from logbook import Logger diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py index f65104d..c9da2e7 100644 --- a/litecord/pubsub/dispatcher.py +++ b/litecord/pubsub/dispatcher.py @@ -37,6 +37,14 @@ class Dispatcher: """Unsubscribe an elemtnt from the channel/key.""" raise NotImplementedError + async def dispatch_filter(self, _key, _func, *_args): + """Selectively dispatch to the list of subscribed users. + + The selection logic is completly arbitraty and up to the + Pub/Sub backend. + """ + raise NotImplementedError + async def dispatch(self, _key, *_args): """Dispatch an event to the given channel/key.""" raise NotImplementedError diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index a05373a..fcb0f05 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -1,4 +1,3 @@ -from collections import defaultdict from typing import Any from logbook import Logger @@ -47,6 +46,8 @@ class GuildDispatcher(DispatcherWithState): # when subbing a user to the guild, we should sub them # to every channel they have access to, in the guild. + # TODO: check for permissions + await self._chan_action('sub', guild_id, user_id) async def unsub(self, guild_id: int, user_id: int): @@ -56,9 +57,10 @@ class GuildDispatcher(DispatcherWithState): # same thing happening from sub() happens on unsub() await self._chan_action('unsub', guild_id, user_id) - async def dispatch(self, guild_id: int, - event: str, data: Any): - """Dispatch an event to all subscribers of the guild.""" + async def dispatch_filter(self, guild_id: int, func, + event: str, data: Any): + """Selectively dispatch to session ids that have + func(session_id) true.""" user_ids = self.state[guild_id] dispatched = 0 @@ -75,8 +77,22 @@ class GuildDispatcher(DispatcherWithState): await self.unsub(guild_id, user_id) continue + # filter the ones that matter + states = list(filter( + lambda state: func(state.session_id), states + )) + dispatched += await self._dispatch_states( states, event, data) log.info('Dispatched {} {!r} to {} states', guild_id, event, dispatched) + + async def dispatch(self, guild_id: int, + event: str, data: Any): + """Dispatch an event to all subscribers of the guild.""" + await self.dispatch_filter( + guild_id, + lambda sess_id: True, + event, data, + ) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py new file mode 100644 index 0000000..d94a0bf --- /dev/null +++ b/litecord/pubsub/lazy_guild.py @@ -0,0 +1,736 @@ +""" +Main code for Lazy Guild implementation in litecord. +""" +import pprint +from dataclasses import dataclass, asdict +from collections import defaultdict +from typing import Any, List, Dict, Union + +from logbook import Logger + +from litecord.pubsub.dispatcher import Dispatcher +from litecord.permissions import ( + Permissions, overwrite_find_mix, get_permissions, role_permissions +) +from litecord.utils import index_by_func + +log = Logger(__name__) + +GroupID = Union[int, str] +Presence = Dict[str, Any] + + +@dataclass +class GroupInfo: + """Store information about a specific group.""" + gid: GroupID + name: str + position: int + permissions: Permissions + + +@dataclass +class MemberList: + """Total information on the guild's member list.""" + groups: List[GroupInfo] = None + group_info: Dict[GroupID, GroupInfo] = None + data: Dict[GroupID, Presence] = None + overwrites: Dict[int, Dict[str, Any]] = None + + def __bool__(self): + """Return if the current member list is fully initialized.""" + list_dict = asdict(self) + return all(v is not None for v in list_dict.values()) + + def __iter__(self): + """Iterate over all groups in the correct order. + + Yields a tuple containing :class:`GroupInfo` and + the List[Presence] for the group. + """ + if not self.groups: + return + + for group in self.groups: + yield group, self.data[group.gid] + + +@dataclass +class Operation: + """Represents a member list operation.""" + list_op: str + params: Dict[str, Any] + + @property + def to_dict(self) -> dict: + res = { + 'op': self.list_op + } + + if self.list_op == 'SYNC': + res['items'] = self.params['items'] + + if self.list_op in ('SYNC', 'INVALIDATE'): + res['range'] = self.params['range'] + + if self.list_op in ('INSERT', 'DELETE', 'UPDATE'): + res['index'] = self.params['index'] + + if self.list_op in ('INSERT', 'UPDATE'): + res['item'] = self.params['item'] + + return res + + +def _to_simple_group(presence: dict) -> str: + """Return a simple group (not a role), given a presence.""" + return 'offline' if presence['status'] == 'offline' else 'online' + + +def display_name(member_nicks: Dict[str, str], presence: Presence) -> str: + """Return the display name of a presence. + + Used to sort groups. + """ + uid = presence['user']['id'] + + uname = presence['user']['username'] + nick = member_nicks.get(uid) + + return nick or uname + + +class GuildMemberList: + """This class stores the current member list information + for a guild (by channel). + + As channels can have different sets of roles that can + read them and so, different lists, this is more of a + "channel member list" than a guild member list. + + Attributes + ---------- + main_lg: LazyGuildDispatcher + Main instance of :class:`LazyGuildDispatcher`, + so that we're able to use things such as :class:`Storage`. + guild_id: int + The Guild ID this instance is referring to. + channel_id: int + The Channel ID this instance is referring to. + member_list: List + The actual member list information. + state: set + The set of session IDs that are subscribed to the guild. + + User IDs being used as the identifier in GuildMemberList + is a wrong assumption. It is true Discord rolled out + lazy guilds to all of the userbase, but users that are bots, + for example, can still rely on PRESENCE_UPDATEs. + """ + def __init__(self, guild_id: int, + channel_id: int, main_lg): + self.guild_id = guild_id + self.channel_id = channel_id + + self.main = main_lg + self.list = MemberList(None, None, None, None) + + #: store the states that are subscribed to the list + # type is{session_id: set[list]} + self.state = defaultdict(set) + + @property + def storage(self): + """Get the global :class:`Storage` instance.""" + return self.main.app.storage + + @property + def presence(self): + """Get the global :class:`PresenceManager` instance.""" + return self.main.app.presence + + @property + def state_man(self): + """Get the global :class:`StateManager` instance.""" + return self.main.app.state_manager + + @property + def list_id(self): + """get the id of the member list.""" + return ('everyone' + if self.channel_id == self.guild_id + else str(self.channel_id)) + + def _set_empty_list(self): + """Set the member list as being empty.""" + self.list = MemberList(None, None, None, None) + + async def _init_check(self): + """Check if the member list is initialized before + messing with it.""" + if not self.list: + await self._init_member_list() + + async def _fetch_overwrites(self): + overwrites = await self.storage.chan_overwrites(self.channel_id) + overwrites = {int(ov['id']): ov for ov in overwrites} + self.list.overwrites = overwrites + + def _calc_member_group(self, roles: List[int], status: str): + """Calculate the best fitting group for a member, + given their roles and their current status.""" + try: + # the first group in the list + # that the member is entitled to is + # the selected group for the member. + group_id = next(g.gid for g in self.list.groups + if g.gid in roles) + except StopIteration: + # no group was found, so we fallback + # to simple group" + group_id = _to_simple_group({'status': status}) + + return group_id + + async def get_roles(self) -> List[GroupInfo]: + """Get role information, but only: + - the ID + - the name + - the position + - the permissions + + of all HOISTED roles AND roles that + have permissions to read the channel + being referred to this :class:`GuildMemberList` + instance. + + The list is sorted by position. + """ + roledata = await self.storage.db.fetch(""" + SELECT id, name, hoist, position, permissions + FROM roles + WHERE guild_id = $1 + """, self.guild_id) + + hoisted = [ + GroupInfo(row['id'], row['name'], + row['position'], row['permissions']) + for row in roledata if row['hoist'] + ] + + # sort role list by position + hoisted = sorted(hoisted, key=lambda group: group.position) + + # we need to store them for later on + # for members + await self._fetch_overwrites() + + def _can_read_chan(group: GroupInfo): + # get the base role perms + role_perms = group.permissions + + # then the final perms for that role if + # any overwrite exists in the channel + final_perms = overwrite_find_mix( + role_perms, self.list.overwrites, group.gid) + + # update the group's permissions + # with the mixed ones + group.permissions = final_perms + + # if the role can read messages, then its + # part of the group. + return final_perms.bits.read_messages + + return list(filter(_can_read_chan, hoisted)) + + async def set_groups(self): + """Get the groups for the member list.""" + role_groups = await self.get_roles() + role_ids = [g.gid for g in role_groups] + + self.list.groups = role_ids + ['online', 'offline'] + + # inject default groups 'online' and 'offline' + self.list.groups = role_ids + [ + GroupInfo('online', 'online', -1, -1), + GroupInfo('offline', 'offline', -1, -1) + ] + self.list.group_info = {g.gid: g for g in role_groups} + + async def get_group(self, member_id: int, + roles: List[Union[str, int]], + status: str) -> int: + """Return a fitting group ID for the user.""" + member_roles = list(map(int, roles)) + + # get the member's permissions relative to the channel + # (accounting for channel overwrites) + member_perms = await get_permissions( + member_id, self.channel_id, storage=self.storage) + + if not member_perms.bits.read_messages: + return None + + # if the member is offline, we + # default give them the offline group. + group_id = ('offline' if status == 'offline' + else self._calc_member_group(member_roles, status)) + + return group_id + + async def _pass_1(self, guild_presences: List[Presence]): + """First pass on generating the member list. + + This assigns all presences a single group. + """ + for presence in guild_presences: + member_id = int(presence['user']['id']) + + group_id = await self.get_group( + member_id, presence['roles'], presence['status'] + ) + + self.list.data[group_id].append(presence) + + async def get_member_nicks_dict(self) -> dict: + """Get a dictionary with nickname information.""" + members = await self.storage.get_member_data(self.guild_id) + + # make a dictionary of member ids to nicknames + # so we don't need to keep querying the db on + # every loop iteration + member_nicks = {m['user']['id']: m.get('nick') + for m in members} + + return member_nicks + + async def _sort_groups(self): + member_nicks = await self.get_member_nicks_dict() + + for group_members in self.list.data.values(): + + # this should update the list in-place + group_members.sort( + key=lambda p: display_name(member_nicks, p)) + + async def _init_member_list(self): + """Generate the main member list with groups.""" + member_ids = await self.storage.get_member_ids(self.guild_id) + + guild_presences = await self.presence.guild_presences( + member_ids, self.guild_id) + + await self.set_groups() + + log.debug('{} presences, {} groups', + len(guild_presences), + len(self.list.groups)) + + self.list.data = {group.gid: [] for group in self.list.groups} + + # first pass: set which presences + # go to which groups + await self._pass_1(guild_presences) + + # second pass: sort each group's members + # by the display name + await self._sort_groups() + + @property + def items(self) -> list: + """Main items list.""" + + # TODO: maybe make this stored in the list + # so we don't need to keep regenning? + + if not self.list: + return [] + + res = [] + + # NOTE: maybe use map()? + for group, presences in self.list: + res.append({ + 'group': { + 'id': group.gid, + 'count': len(presences), + } + }) + + for presence in presences: + res.append({ + 'member': presence + }) + + return res + + async def sub(self, _session_id: str): + """Subscribe a shard to the member list.""" + await self._init_check() + + def unsub(self, session_id: str): + """Unsubscribe a shard from the member list""" + try: + self.state.pop(session_id) + except KeyError: + pass + + # once we reach 0 subscribers, + # we drop the current member list we have (for memory) + # but keep the GuildMemberList running (as + # uninitialized) for a future subscriber. + + if not self.state: + self._set_empty_list() + + def get_state(self, session_id: str): + try: + state = self.state_man.fetch_raw(session_id) + return state + except KeyError: + self.unsub(session_id) + return + + async def _dispatch_sess(self, session_ids: List[str], + operations: List[Operation]): + + # construct the payload to dispatch + payload = { + 'id': self.list_id, + 'guild_id': str(self.guild_id), + + 'groups': [ + { + 'count': len(presences), + 'id': group.gid + } for group, presences in self.list + ], + + 'ops': [ + operation.to_dict + for operation in operations + ] + } + + states = map(self.get_state, session_ids) + states = filter(lambda state: state is not None, states) + + dispatched = [] + + for state in states: + await state.ws.dispatch( + 'GUILD_MEMBER_LIST_UPDATE', payload) + + dispatched.append(state.session_id) + + return dispatched + + async def shard_query(self, session_id: str, ranges: list): + """Send a GUILD_MEMBER_LIST_UPDATE event + for a shard that is querying about the member list. + + Paramteters + ----------- + session_id: str + The Session ID querying information. + channel_id: int + The Channel ID that we want information on. + ranges: List[List[int]] + ranges of the list that we want. + """ + + # a guild list with a channel id of the guild + # represents the 'everyone' global list. + list_id = self.list_id + + # if everyone can read the channel, + # we direct the request to the 'everyone' gml instance + # instead of the current one. + everyone_perms = await role_permissions( + self.guild_id, + self.guild_id, + self.channel_id, + storage=self.storage + ) + + if everyone_perms.bits.read_messages and list_id != 'everyone': + everyone_gml = await self.main.get_gml(self.guild_id) + + return await everyone_gml.shard_query( + session_id, ranges + ) + + await self._init_check() + + ops = [] + + for start, end in ranges: + itemcount = end - start + + # ignore incorrect ranges + if itemcount < 0: + continue + + self.state[session_id].add((start, end)) + + ops.append(Operation('SYNC', { + 'range': [start, end], + 'items': self.items[start:end] + })) + + await self._dispatch_sess([session_id], ops) + + def get_item_index(self, user_id: Union[str, int]): + """Get the item index a user is on.""" + def _get_id(item): + # item can be a group item or a member item + return item.get('member', {}).get('user', {}).get('id') + + # get the updated item's index + return index_by_func( + lambda p: _get_id(p) == str(user_id), + self.items + ) + + def state_is_subbed(self, item_index, session_id: str) -> bool: + """Return if a state's ranges include the given + item index.""" + + ranges = self.state[session_id] + + for range_start, range_end in ranges: + if range_start <= item_index <= range_end: + return True + + return False + + def get_subs(self, item_index: int) -> filter: + """Get the list of subscribed states to a given item.""" + return filter( + lambda sess_id: self.state_is_subbed(item_index, sess_id), + self.state.keys() + ) + + async def _pres_update_simple(self, user_id: int): + item_index = self.get_item_index(user_id) + + if item_index is None: + log.warning('lazy guild got invalid pres update uid={}', + user_id) + return [] + + item = self.items[item_index] + session_ids = self.get_subs(item_index) + + # simple update means we just give an UPDATE + # operation + return await self._dispatch_sess( + session_ids, + [ + Operation('UPDATE', { + 'index': item_index, + 'item': item, + }) + ] + ) + + async def _pres_update_complex(self, user_id: int, + old_group: str, old_index: int, + new_group: str): + """Move a member between groups.""" + log.debug('complex update: uid={} old={} old_idx={} new={}', + user_id, old_group, old_index, new_group) + old_group_presences = self.list.data[old_group] + old_item_index = self.get_item_index(user_id) + + # make a copy of current presence to insert in the new group + current_presence = dict(old_group_presences[old_index]) + + # step 1: remove the old presence (old_index is relative + # to the group, and not the items list) + del old_group_presences[old_index] + + # we need to insert current_presence to the new group + # but we also need to calculate its index to insert on. + presences = self.list.data[new_group] + + best_index = 0 + member_nicks = await self.get_member_nicks_dict() + current_name = display_name(member_nicks, current_presence) + + # go through each one until we find the best placement + for presence in presences: + name = display_name(member_nicks, presence) + + print(name, current_name, name < current_name) + + # TODO: check if this works + if name < current_name: + break + + best_index += 1 + + # insert the presence at the index + presences.insert(best_index + 1, current_presence) + + new_item_index = self.get_item_index(user_id) + + log.debug('assigned new item index {} to uid {}', + new_item_index, user_id) + + session_ids_old = self.get_subs(old_item_index) + session_ids_new = self.get_subs(new_item_index) + + # dispatch events to both the old states and + # new states. + return await self._dispatch_sess( + # inefficient, but necessary since we + # want to merge both session ids. + list(session_ids_old) + list(session_ids_new), + [ + Operation('DELETE', { + 'index': old_item_index, + }), + + Operation('INSERT', { + 'index': new_item_index, + 'item': { + 'member': current_presence + } + }) + ] + ) + + async def pres_update(self, user_id: int, + partial_presence: Dict[str, Any]): + """Update a presence inside the member list. + + There are 4 types of updates that can happen for a user in a group: + - from 'offline' to any + - from any to 'offline' + - from any to any + - from G to G (with G being any group) + + any: 'online' | role_id + + All first, second, and third updates are 'complex' updates, + which means we'll have to change the group the user is on + to account for them. + + The fourth is a 'simple' change, since we're not changing + the group a user is on, and so there's less overhead + involved. + """ + await self._init_check() + + old_group, old_index, old_presence = None, None, None + + for group, presences in self.list: + p_idx = index_by_func( + lambda p: p['user']['id'] == str(user_id), + presences) + + log.debug('p_idx for group {!r} = {}', + group.gid, p_idx) + + if p_idx is None: + log.debug('skipping group {}', group) + continue + + # make a copy since we're modifying in-place + old_group = group.gid + old_index = p_idx + old_presence = dict(presences[p_idx]) + + # be ready if it is a simple update + presences[p_idx].update(partial_presence) + break + + if not old_group: + log.warning('pres update with unknown old group uid={}', + user_id) + return [] + + roles = partial_presence.get('roles', old_presence['roles']) + new_status = partial_presence.get('status', old_presence['status']) + + new_group = await self.get_group(user_id, roles, new_status) + + log.debug('pres update: gid={} cid={} old_g={} new_g={}', + self.guild_id, self.channel_id, old_group, new_group) + + # if we're going to the same group, + # treat this as a simple update + if old_group == new_group: + return await self._pres_update_simple(user_id) + + return await self._pres_update_complex( + user_id, old_group, old_index, new_group) + + async def dispatch(self, event: str, data: Any): + """Modify the member list and dispatch the respective + events to subscribed shards. + + GuildMemberList stores the current guilds' list + in its :attr:`GuildMemberList.list` attribute, + with that attribute being modified via different + calls to :meth:`GuildMemberList.dispatch` + """ + + # if no subscribers, drop event + if not self.list: + return + + +class LazyGuildDispatcher(Dispatcher): + """Main class holding the member lists for lazy guilds.""" + # channel ids + KEY_TYPE = int + + # the session ids subscribing to channels + VAL_TYPE = str + + def __init__(self, main): + super().__init__(main) + + self.storage = main.app.storage + + # {chan_id: gml, ...} + self.state = {} + + #: store which guilds have their + # respective GMLs + # {guild_id: [chan_id, ...], ...} + self.guild_map = defaultdict(list) + + async def get_gml(self, channel_id: int): + """Get a guild list for a channel ID, + generating it if it doesn't exist.""" + try: + return self.state[channel_id] + except KeyError: + guild_id = await self.storage.guild_from_channel( + channel_id + ) + + # if we don't find a guild, we just + # set it the same as the channel. + if not guild_id: + guild_id = channel_id + + gml = GuildMemberList(guild_id, channel_id, self) + self.state[channel_id] = gml + self.guild_map[guild_id].append(channel_id) + return gml + + def get_gml_guild(self, guild_id: int) -> List[GuildMemberList]: + """Get all member lists for a given guild.""" + return list(map( + self.state.get, + self.guild_map[guild_id] + )) + + async def unsub(self, chan_id, session_id): + gml = await self.get_gml(chan_id) + gml.unsub(session_id) diff --git a/litecord/ratelimits/bucket.py b/litecord/ratelimits/bucket.py new file mode 100644 index 0000000..dabb0ae --- /dev/null +++ b/litecord/ratelimits/bucket.py @@ -0,0 +1,113 @@ +""" +main litecord ratelimiting code + + This code was copied from elixire's ratelimiting, + which in turn is a work on top of discord.py's ratelimiting. +""" +import time + + +class RatelimitBucket: + """Main ratelimit bucket class.""" + def __init__(self, tokens, second): + self.requests = tokens + self.second = second + + self._window = 0.0 + self._tokens = self.requests + self.retries = 0 + self._last = 0.0 + + def get_tokens(self, current): + """Get the current amount of available tokens.""" + if not current: + current = time.time() + + # by default, use _tokens + tokens = self._tokens + + # if current timestamp is above _window + seconds + # reset tokens to self.requests (default) + if current > self._window + self.second: + tokens = self.requests + + return tokens + + def update_rate_limit(self): + """Update current ratelimit state.""" + current = time.time() + self._last = current + self._tokens = self.get_tokens(current) + + # we are using the ratelimit for the first time + # so set current ratelimit window to right now + if self._tokens == self.requests: + self._window = current + + # Are we currently ratelimited? + if self._tokens == 0: + self.retries += 1 + return self.second - (current - self._window) + + # if not ratelimited, remove a token + self.retries = 0 + self._tokens -= 1 + + # if we got ratelimited after that token removal, + # set window to now + if self._tokens == 0: + self._window = current + + def reset(self): + """Reset current ratelimit to default state.""" + self._tokens = self.requests + self._last = 0.0 + self.retries = 0 + + def copy(self): + """Create a copy of this ratelimit. + + Used to manage multiple ratelimits to users. + """ + return RatelimitBucket(self.requests, + self.second) + + def __repr__(self): + return (f'') + + +class Ratelimit: + """Manages buckets.""" + def __init__(self, tokens, second, keys=None): + self._cache = {} + if keys is None: + keys = tuple() + self.keys = keys + self._cooldown = RatelimitBucket(tokens, second) + + def __repr__(self): + return (f'') + + def _verify_cache(self): + current = time.time() + dead_keys = [k for k, v in self._cache.items() + if current > v._last + v.second] + + for k in dead_keys: + del self._cache[k] + + def get_bucket(self, key) -> RatelimitBucket: + if not self._cooldown: + return None + + self._verify_cache() + + if key not in self._cache: + bucket = self._cooldown.copy() + self._cache[key] = bucket + else: + bucket = self._cache[key] + + return bucket diff --git a/litecord/ratelimits/handler.py b/litecord/ratelimits/handler.py new file mode 100644 index 0000000..9a8987e --- /dev/null +++ b/litecord/ratelimits/handler.py @@ -0,0 +1,83 @@ +from quart import current_app as app, request, g + +from litecord.errors import Ratelimited +from litecord.auth import token_check, Unauthorized + + +async def _check_bucket(bucket): + retry_after = bucket.update_rate_limit() + + request.bucket = bucket + + if retry_after: + request.retry_after = retry_after + + raise Ratelimited('You are being rate limited.', { + 'retry_after': int(retry_after * 1000), + 'global': request.bucket_global, + }) + + +async def _handle_global(ratelimit): + """Global ratelimit is per-user.""" + try: + user_id = await token_check() + except Unauthorized: + user_id = request.remote_addr + + request.bucket_global = True + bucket = ratelimit.get_bucket(user_id) + await _check_bucket(bucket) + + +async def _handle_specific(ratelimit): + try: + user_id = await token_check() + except Unauthorized: + user_id = request.remote_addr + + # construct the key based on the ratelimit.keys + keys = ratelimit.keys + + # base key is the user id + key_components = [f'user_id:{user_id}'] + + for key in keys: + val = request.view_args[key] + key_components.append(f'{key}:{val}') + + bucket_key = ':'.join(key_components) + bucket = ratelimit.get_bucket(bucket_key) + await _check_bucket(bucket) + + +async def ratelimit_handler(): + """Main ratelimit handler. + + Decides on which ratelimit to use. + """ + rule = request.url_rule + + if rule is None: + return await _handle_global( + app.ratelimiter.global_bucket + ) + + # rule.endpoint is composed of '.' + # and so we can use that to make routes with different + # methods have different ratelimits + rule_path = rule.endpoint + + # some request ratelimit context. + # TODO: maybe put those in a namedtuple or contextvar of sorts? + request.bucket = None + request.retry_after = None + request.bucket_global = False + + try: + ratelimit = app.ratelimiter.get_ratelimit(rule_path) + await _handle_specific(ratelimit) + except KeyError: + await _handle_global( + app.ratelimiter.global_bucket + ) diff --git a/litecord/ratelimits/main.py b/litecord/ratelimits/main.py new file mode 100644 index 0000000..10d219b --- /dev/null +++ b/litecord/ratelimits/main.py @@ -0,0 +1,56 @@ +from litecord.ratelimits.bucket import Ratelimit + +""" +REST: + POST Message | 5/5s | per-channel + DELETE Message | 5/1s | per-channel + PUT/DELETE Reaction | 1/0.25s | per-channel + PATCH Member | 10/10s | per-guild + PATCH Member Nick | 1/1s | per-guild + PATCH Username | 2/3600s | per-account + |All Requests| | 50/1s | per-account +WS: + Gateway Connect | 1/5s | per-account + Presence Update | 5/60s | per-session + |All Sent Messages| | 120/60s | per-session +""" + +REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id')) + +RATELIMITS = { + 'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')), + 'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')), + + # all of those share the same bucket. + 'channel_reactions.add_reaction': REACTION_BUCKET, + 'channel_reactions.remove_own_reaction': REACTION_BUCKET, + 'channel_reactions.remove_user_reaction': REACTION_BUCKET, + + 'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')), + 'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')), + + # this only applies to username. + # 'users.patch_me': Ratelimit(2, 3600), + + '_ws.connect': Ratelimit(1, 5), + '_ws.presence': Ratelimit(5, 60), + '_ws.messages': Ratelimit(120, 60), + + # 1000 / 4h for new session issuing + '_ws.session': Ratelimit(1000, 14400) +} + +class RatelimitManager: + """Manager for the bucket managers""" + def __init__(self): + self._ratelimiters = {} + self.global_bucket = Ratelimit(50, 1) + self._fill_rtl() + + def _fill_rtl(self): + for path, rtl in RATELIMITS.items(): + self._ratelimiters[path] = rtl + + def get_ratelimit(self, key: str) -> Ratelimit: + """Get the :class:`Ratelimit` instance for a given path.""" + return self._ratelimiters.get(key, self.global_bucket) diff --git a/litecord/schemas.py b/litecord/schemas.py index 1c3eda3..e74dda0 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -1,11 +1,16 @@ import re +from typing import Union, Dict, List, Any from cerberus import Validator from logbook import Logger from .errors import BadRequest -from .enums import ActivityType, StatusType, ExplicitFilter, \ - RelationshipType, MessageNotifications +from .permissions import Permissions +from .types import Color +from .enums import ( + ActivityType, StatusType, ExplicitFilter, RelationshipType, + MessageNotifications, ChannelType, VerificationLevel +) log = Logger(__name__) @@ -24,13 +29,21 @@ EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M) ANIMOJI_MENTION = re.compile(r'', re.A | re.M) +def _in_enum(enum, value: int): + try: + enum(value) + return True + except ValueError: + return False + + class LitecordValidator(Validator): def _validate_type_username(self, value: str) -> bool: """Validate against the username regex.""" return bool(USERNAME_REGEX.match(value)) def _validate_type_email(self, value: str) -> bool: - """Validate against the username regex.""" + """Validate against the email regex.""" return bool(EMAIL_REGEX.match(value)) def _validate_type_b64_icon(self, value: str) -> bool: @@ -56,11 +69,17 @@ class LitecordValidator(Validator): def _validate_type_voice_region(self, value: str) -> bool: # TODO: complete this list - return value in ('brazil', 'us-east', 'us-west', 'us-south', 'russia') + return value.lower() in ('brazil', 'us-east', 'us-west', 'us-south', 'russia') + + def _validate_type_verification_level(self, value: int) -> bool: + return _in_enum(VerificationLevel, value) def _validate_type_activity_type(self, value: int) -> bool: return value in ActivityType.values() + def _validate_type_channel_type(self, value: int) -> bool: + return value in ChannelType.values() + def _validate_type_status_external(self, value: str) -> bool: statuses = StatusType.values() @@ -94,11 +113,31 @@ class LitecordValidator(Validator): return val in MessageNotifications.values() + def _validate_type_guild_name(self, value: str) -> bool: + return 2 <= len(value) <= 100 -def validate(reqjson, schema, raise_err: bool = True): + def _validate_type_role_name(self, value: str) -> bool: + return 1 <= len(value) <= 100 + + def _validate_type_channel_name(self, value: str) -> bool: + # for now, we'll use the same validation for guild_name + return self._validate_type_guild_name(value) + + +def validate(reqjson: Union[Dict, List], schema: Dict, + raise_err: bool = True) -> Union[Dict, List]: + """Validate a given document (user-input) and give + the correct document as a result. + """ validator = LitecordValidator(schema) - if not validator.validate(reqjson): + try: + valid = validator.validate(reqjson) + except Exception: + log.exception('Error while validating') + raise Exception(f'Error while validating: {reqjson}') + + if not valid: errs = validator.errors log.warning('Error validating doc {!r}: {!r}', reqjson, errs) @@ -146,16 +185,55 @@ USER_UPDATE = { } +PARTIAL_ROLE_GUILD_CREATE = { + 'type': 'dict', + 'schema': { + 'name': {'type': 'role_name'}, + 'color': {'type': 'number', 'default': 0}, + 'hoist': {'type': 'boolean', 'default': False}, + + # NOTE: no position on partial role (on guild create) + + 'permissions': {'coerce': Permissions, 'required': False}, + 'mentionable': {'type': 'boolean', 'default': False}, + } +} + +PARTIAL_CHANNEL_GUILD_CREATE = { + 'type': 'dict', + 'schema': { + 'name': {'type': 'channel_name'}, + 'type': {'type': 'channel_type'}, + } +} + +GUILD_CREATE = { + 'name': {'type': 'guild_name'}, + 'region': {'type': 'voice_region'}, + 'icon': {'type': 'b64_icon', 'required': False, 'nullable': True}, + + 'verification_level': { + 'type': 'verification_level', 'default': 0}, + 'default_message_notifications': { + 'type': 'msg_notifications', 'default': 0}, + 'explicit_content_filter': { + 'type': 'explicit', 'default': 0}, + + 'roles': { + 'type': 'list', 'required': False, + 'schema': PARTIAL_ROLE_GUILD_CREATE}, + 'channels': { + 'type': 'list', 'default': [], 'schema': PARTIAL_CHANNEL_GUILD_CREATE}, +} + GUILD_UPDATE = { 'name': { - 'type': 'string', - 'minlength': 2, - 'maxlength': 100, + 'type': 'guild_name', 'required': False }, 'region': {'type': 'voice_region', 'required': False}, - 'icon': {'type': 'icon', 'required': False}, + 'icon': {'type': 'b64_icon', 'required': False}, 'verification_level': {'type': 'verification_level', 'required': False}, 'default_message_notifications': { @@ -173,13 +251,93 @@ GUILD_UPDATE = { } +CHAN_OVERWRITE = { + 'id': {'coerce': int}, + 'type': {'type': 'string', 'allowed': ['role', 'member']}, + 'allow': {'coerce': Permissions}, + 'deny': {'coerce': Permissions} +} + + +CHAN_UPDATE = { + 'name': { + 'type': 'string', 'minlength': 2, + 'maxlength': 100, 'required': False}, + + 'position': {'coerce': int, 'required': False}, + + 'topic': { + 'type': 'string', 'minlength': 0, + 'maxlength': 1024, 'required': False}, + + 'nsfw': {'type': 'boolean', 'required': False}, + 'rate_limit_per_user': { + 'coerce': int, 'min': 0, + 'max': 120, 'required': False}, + + 'bitrate': { + 'coerce': int, 'min': 8000, + + # NOTE: 'max' is 96000 for non-vip guilds + 'max': 128000, 'required': False}, + + 'user_limit': { + # user_limit being 0 means infinite. + 'coerce': int, 'min': 0, + 'max': 99, 'required': False + }, + + 'permission_overwrites': { + 'type': 'list', + 'schema': {'type': 'dict', 'schema': CHAN_OVERWRITE}, + 'required': False + }, + + 'parent_id': {'coerce': int, 'required': False, 'nullable': True} + + +} + + +ROLE_CREATE = { + 'name': {'type': 'string', 'default': 'new role'}, + 'permissions': {'coerce': Permissions, 'nullable': True}, + 'color': {'coerce': Color, 'default': 0}, + 'hoist': {'type': 'boolean', 'default': False}, + 'mentionable': {'type': 'boolean', 'default': False}, +} + +ROLE_UPDATE = { + 'name': {'type': 'string', 'required': False}, + 'permissions': {'coerce': Permissions, 'required': False}, + 'color': {'coerce': Color, 'required': False}, + 'hoist': {'type': 'boolean', 'required': False}, + 'mentionable': {'type': 'boolean', 'required': False}, +} + + +ROLE_UPDATE_POSITION = { + 'roles': { + 'type': 'list', + 'schema': { + 'type': 'dict', + 'schema': { + 'id': {'coerce': int}, + 'position': {'coerce': int}, + }, + } + } +} + + MEMBER_UPDATE = { 'nick': { - 'type': 'nickname', + 'type': 'username', 'minlength': 1, 'maxlength': 100, 'required': False, }, - 'roles': {'type': 'list', 'required': False}, + 'roles': {'type': 'list', 'required': False, + 'schema': {'coerce': int}}, 'mute': {'type': 'boolean', 'required': False}, 'deaf': {'type': 'boolean', 'required': False}, 'channel_id': {'type': 'snowflake', 'required': False}, @@ -196,57 +354,60 @@ MESSAGE_CREATE = { GW_ACTIVITY = { - 'name': {'type': 'string', 'required': True}, - 'type': {'type': 'activity_type', 'required': True}, + 'type': 'dict', + 'schema': { + 'name': {'type': 'string', 'required': True}, + 'type': {'type': 'activity_type', 'required': True}, - 'url': {'type': 'string', 'required': False, 'nullable': True}, + 'url': {'type': 'string', 'required': False, 'nullable': True}, - 'timestamps': { - 'type': 'dict', - 'required': False, - 'schema': { - 'start': {'type': 'number', 'required': True}, - 'end': {'type': 'number', 'required': True}, + 'timestamps': { + 'type': 'dict', + 'required': False, + 'schema': { + 'start': {'type': 'number', 'required': True}, + 'end': {'type': 'number', 'required': False}, + }, }, - }, - 'application_id': {'type': 'snowflake', 'required': False, - 'nullable': False}, - 'details': {'type': 'string', 'required': False, 'nullable': True}, - 'state': {'type': 'string', 'required': False, 'nullable': True}, + 'application_id': {'type': 'snowflake', 'required': False, + 'nullable': False}, + 'details': {'type': 'string', 'required': False, 'nullable': True}, + 'state': {'type': 'string', 'required': False, 'nullable': True}, - 'party': { - 'type': 'dict', - 'required': False, - 'schema': { - 'id': {'type': 'snowflake', 'required': False}, - 'size': {'type': 'list', 'required': False}, - } - }, + 'party': { + 'type': 'dict', + 'required': False, + 'schema': { + 'id': {'type': 'snowflake', 'required': False}, + 'size': {'type': 'list', 'required': False}, + } + }, - 'assets': { - 'type': 'dict', - 'required': False, - 'schema': { - 'large_image': {'type': 'snowflake', 'required': False}, - 'large_text': {'type': 'string', 'required': False}, - 'small_image': {'type': 'snowflake', 'required': False}, - 'small_text': {'type': 'string', 'required': False}, - } - }, + 'assets': { + 'type': 'dict', + 'required': False, + 'schema': { + 'large_image': {'type': 'snowflake', 'required': False}, + 'large_text': {'type': 'string', 'required': False}, + 'small_image': {'type': 'snowflake', 'required': False}, + 'small_text': {'type': 'string', 'required': False}, + } + }, - 'secrets': { - 'type': 'dict', - 'required': False, - 'schema': { - 'join': {'type': 'string', 'required': False}, - 'spectate': {'type': 'string', 'required': False}, - 'match': {'type': 'string', 'required': False}, - } - }, + 'secrets': { + 'type': 'dict', + 'required': False, + 'schema': { + 'join': {'type': 'string', 'required': False}, + 'spectate': {'type': 'string', 'required': False}, + 'match': {'type': 'string', 'required': False}, + } + }, - 'instance': {'type': 'boolean', 'required': False}, - 'flags': {'type': 'number', 'required': False}, + 'instance': {'type': 'boolean', 'required': False}, + 'flags': {'type': 'number', 'required': False}, + } } GW_STATUS_UPDATE = { @@ -335,6 +496,8 @@ USER_SETTINGS = { 'show_current_game': {'type': 'boolean', 'required': False}, 'timezone_offset': {'type': 'number', 'required': False}, + + 'status': {'type': 'status_external', 'required': False} } RELATIONSHIP = { @@ -395,3 +558,7 @@ GUILD_SETTINGS = { 'required': False, } } + +GUILD_PRUNE = { + 'days': {'type': 'number', 'coerce': int, 'min': 1} +} diff --git a/litecord/storage.py b/litecord/storage.py index a7e37eb..b78c2b0 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -5,6 +5,9 @@ from logbook import Logger from .enums import ChannelType, RelationshipType from .schemas import USER_MENTION, ROLE_MENTION +from litecord.blueprints.channel.reactions import ( + emoji_info_from_str, EmojiType, emoji_sql, partial_emoji +) log = Logger(__name__) @@ -163,17 +166,40 @@ class Storage: WHERE guild_id = $1 and user_id = $2 """, guild_id, member_id) - async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: - members_roles = await self.db.fetch(""" + async def get_member_role_ids(self, guild_id: int, + member_id: int) -> List[int]: + """Get a list of role IDs that are on a member.""" + roles = await self.db.fetch(""" SELECT role_id::text FROM member_roles WHERE guild_id = $1 AND user_id = $2 """, guild_id, member_id) + roles = [r['role_id'] for r in roles] + + try: + roles.remove(str(guild_id)) + except ValueError: + # if the @everyone role isn't in, we add it + # to member_roles automatically (it won't + # be shown on the API, though). + await self.db.execute(""" + INSERT INTO member_roles (user_id, guild_id, role_id) + VALUES ($1, $2, $3) + """, member_id, guild_id, guild_id) + + return list(map(str, roles)) + + async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: + roles = await self.get_member_role_ids(guild_id, member_id) return { 'user': await self.get_user(member_id), 'nick': row['nickname'], - 'roles': [row[0] for row in members_roles], + + # we don't send the @everyone role's id to + # the user since it is known that everyone has + # that role. + 'roles': roles, 'joined_at': row['joined_at'].isoformat(), 'deaf': row['deafened'], 'mute': row['muted'], @@ -289,7 +315,7 @@ class Storage: WHERE channels.id = $1 """, channel_id) - async def _chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]: + async def chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]: overwrite_rows = await self.db.fetch(""" SELECT target_type, target_role, target_user, allow, deny FROM channel_overwrites @@ -298,18 +324,20 @@ class Storage: def _overwrite_convert(row): drow = dict(row) - drow['type'] = drow['target_type'] + + target_type = drow['target_type'] + drow['type'] = 'user' if target_type == 0 else 'role' # if type is 0, the overwrite is for a user # if type is 1, the overwrite is for a role drow['id'] = { 0: drow['target_user'], 1: drow['target_role'], - }[drow['type']] + }[target_type] drow['id'] = str(drow['id']) - drow.pop('overwrite_type') + drow.pop('target_type') drow.pop('target_user') drow.pop('target_role') @@ -335,8 +363,8 @@ class Storage: dbase['type'] = chan_type res = await self._channels_extra(dbase) - res['permission_overwrites'] = \ - list(await self._chan_overwrites(channel_id)) + res['permission_overwrites'] = await self.chan_overwrites( + channel_id) res['id'] = str(res['id']) return res @@ -401,8 +429,8 @@ class Storage: res = await self._channels_extra(drow) - res['permission_overwrites'] = \ - list(await self._chan_overwrites(row['id'])) + res['permission_overwrites'] = await self.chan_overwrites( + row['id']) # Making sure. res['id'] = str(res['id']) @@ -440,6 +468,7 @@ class Storage: permissions, managed, mentionable FROM roles WHERE guild_id = $1 + ORDER BY position ASC """, guild_id) return list(map(dict, roledata)) @@ -535,7 +564,70 @@ class Storage: return res - async def get_message(self, message_id: int) -> Dict: + async def get_reactions(self, message_id: int, user_id=None) -> List: + """Get all reactions in a message.""" + reactions = await self.db.fetch(""" + SELECT user_id, emoji_type, emoji_id, emoji_text + FROM message_reactions + WHERE message_id = $1 + ORDER BY react_ts + """, message_id) + + # ordered list of emoji + emoji = [] + + # the current state of emoji info + react_stats = {} + + # to generate the list, we pass through all + # all reactions and insert them all. + + # we can't use a set() because that + # doesn't guarantee any order. + for row in reactions: + etype = EmojiType(row['emoji_type']) + eid, etext = row['emoji_id'], row['emoji_text'] + + # get the main key to use, given + # the emoji information + _, main_emoji = emoji_sql(etype, eid, etext) + + if main_emoji in emoji: + continue + + # maintain order (first reacted comes first + # on the reaction list) + emoji.append(main_emoji) + + react_stats[main_emoji] = { + 'count': 0, + 'me': False, + 'emoji': partial_emoji(etype, eid, etext) + } + + # then the 2nd pass, where we insert + # the info for each reaction in the react_stats + # dictionary + for row in reactions: + etype = EmojiType(row['emoji_type']) + eid, etext = row['emoji_id'], row['emoji_text'] + + # same thing as the last loop, + # extracting main key + _, main_emoji = emoji_sql(etype, eid, etext) + + stats = react_stats[main_emoji] + stats['count'] += 1 + + if row['user_id'] == user_id: + stats['me'] = True + + # after processing reaction counts, + # we get them in the same order + # they were defined in the first loop. + return list(map(react_stats.get, emoji)) + + async def get_message(self, message_id: int, user_id=None) -> Dict: """Get a single message's payload.""" row = await self.db.fetchrow(""" SELECT id::text, channel_id::text, author_id, content, @@ -596,6 +688,8 @@ class Storage: res['mention_roles'] = await self._msg_regex( ROLE_MENTION, _get_role_mention, content) + res['reactions'] = await self.get_reactions(message_id, user_id) + # TODO: handle webhook authors res['author'] = await self.get_user(res['author_id']) res.pop('author_id') @@ -606,9 +700,6 @@ class Storage: # TODO: res['embeds'] res['embeds'] = [] - # TODO: res['reactions'] - res['reactions'] = [] - # TODO: res['pinned'] res['pinned'] = False @@ -966,7 +1057,6 @@ class Storage: """, user_id) for row in settings: - print(dict(row)) gid = int(row['guild_id']) drow = dict(row) diff --git a/litecord/types.py b/litecord/types.py new file mode 100644 index 0000000..4fddfff --- /dev/null +++ b/litecord/types.py @@ -0,0 +1,15 @@ + +class Color: + """Custom color class""" + def __init__(self, val: int): + self.blue = val & 255 + self.green = (val >> 8) & 255 + self.red = (val >> 16) & 255 + + @property + def value(self): + """Give the actual RGB integer encoding this color.""" + return int('%02x%02x%02x' % (self.red, self.green, self.blue), 16) + + def __int__(self): + return self.value diff --git a/litecord/utils.py b/litecord/utils.py index a350dad..2fda9d5 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -22,3 +22,18 @@ async def task_wrapper(name: str, coro): pass except: log.exception('{} task error', name) + + +def dict_get(mapping, key, default): + """Return `default` even when mapping[key] is None.""" + return mapping.get(key) or default + + +def index_by_func(function, indexable: iter) -> int: + """Search in an idexable and return the index number + for an iterm that has func(item) = True.""" + for index, item in enumerate(indexable): + if function(item): + return index + + return None diff --git a/manage.py b/manage.py new file mode 100755 index 0000000..a65b9d6 --- /dev/null +++ b/manage.py @@ -0,0 +1,12 @@ +#!/usr/bin/env python3 +import logging +import sys + +from manage.main import main + +import config + +logging.basicConfig(level=logging.DEBUG) + +if __name__ == '__main__': + sys.exit(main(config)) diff --git a/manage/__init__.py b/manage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/manage/cmd/migration/__init__.py b/manage/cmd/migration/__init__.py new file mode 100644 index 0000000..3a9fa59 --- /dev/null +++ b/manage/cmd/migration/__init__.py @@ -0,0 +1 @@ +from .command import setup as migration diff --git a/manage/cmd/migration/command.py b/manage/cmd/migration/command.py new file mode 100644 index 0000000..a1e87c6 --- /dev/null +++ b/manage/cmd/migration/command.py @@ -0,0 +1,143 @@ +import inspect +from pathlib import Path +from dataclasses import dataclass +from collections import namedtuple +from typing import Dict + +import asyncpg +from logbook import Logger + +log = Logger(__name__) + + +Migration = namedtuple('Migration', 'id name path') + + +@dataclass +class MigrationContext: + """Hold information about migration.""" + migration_folder: Path + scripts: Dict[int, Migration] + + @property + def latest(self): + """Return the latest migration ID.""" + return max(self.scripts.keys()) + + +def make_migration_ctx() -> MigrationContext: + """Create the MigrationContext instance.""" + # taken from https://stackoverflow.com/a/6628348 + script_path = inspect.stack()[0][1] + script_folder = '/'.join(script_path.split('/')[:-1]) + script_folder = Path(script_folder) + + migration_folder = script_folder / 'scripts' + + mctx = MigrationContext(migration_folder, {}) + + for mig_path in migration_folder.glob('*.sql'): + mig_path_str = str(mig_path) + + # extract migration script id and name + mig_filename = mig_path_str.split('/')[-1].split('.')[0] + name_fragments = mig_filename.split('_') + + mig_id = int(name_fragments[0]) + mig_name = '_'.join(name_fragments[1:]) + + mctx.scripts[mig_id] = Migration( + mig_id, mig_name, mig_path) + + return mctx + + +async def _ensure_changelog(app, ctx): + # make sure we have the migration table up + + try: + await app.db.execute(""" + CREATE TABLE migration_log ( + change_num bigint NOT NULL, + + apply_ts timestamp without time zone default + (now() at time zone 'utc'), + + description text, + + PRIMARY KEY (change_num) + ); + """) + + # if we were able to create the + # migration_log table, insert that we are + # on the latest version. + await app.db.execute(""" + INSERT INTO migration_log (change_num, description) + VALUES ($1, $2) + """, ctx.latest, 'migration setup') + except asyncpg.DuplicateTableError: + log.debug('existing migration table') + + +async def apply_migration(app, migration: Migration): + """Apply a single migration.""" + migration_sql = migration.path.read_text(encoding='utf-8') + + try: + await app.db.execute(""" + INSERT INTO migration_log (change_num, description) + VALUES ($1, $2) + """, migration.id, f'migration: {migration.name}') + except asyncpg.UniqueViolationError: + log.warning('already applied {}', migration.id) + return + + await app.db.execute(migration_sql) + log.info('applied {}', migration.id) + + +async def migrate_cmd(app, _args): + """Main migration command. + + This makes sure the database + is updated. + """ + + ctx = make_migration_ctx() + + await _ensure_changelog(app, ctx) + + # local point in the changelog + local_change = await app.db.fetchval(""" + SELECT max(change_num) + FROM migration_log + """) + + local_change = local_change or 0 + latest_change = ctx.latest + + log.debug('local: {}, latest: {}', local_change, latest_change) + + if local_change == latest_change: + print('no changes to do, exiting') + return + + # we do local_change + 1 so we start from the + # next migration to do, end in latest_change + 1 + # because of how range() works. + for idx in range(local_change + 1, latest_change + 1): + migration = ctx.scripts.get(idx) + + print('applying', migration.id, migration.name) + await apply_migration(app, migration) + + +def setup(subparser): + migrate_parser = subparser.add_parser( + 'migrate', + help='Run migration tasks', + description=migrate_cmd.__doc__ + ) + + migrate_parser.set_defaults(func=migrate_cmd) diff --git a/manage/cmd/migration/scripts/1_message_embed_type.sql b/manage/cmd/migration/scripts/1_message_embed_type.sql new file mode 100644 index 0000000..8650558 --- /dev/null +++ b/manage/cmd/migration/scripts/1_message_embed_type.sql @@ -0,0 +1,6 @@ +-- unused tables +DROP TABLE message_embeds; +DROP TABLE embeds; + +ALTER TABLE messages + ADD COLUMN embeds jsonb DEFAULT '[]' diff --git a/manage/main.py b/manage/main.py new file mode 100644 index 0000000..7cb0f99 --- /dev/null +++ b/manage/main.py @@ -0,0 +1,58 @@ +import asyncio +import argparse +from sys import argv +from dataclasses import dataclass + +from logbook import Logger + +from run import init_app_managers, init_app_db +from manage.cmd.migration import migration + +log = Logger(__name__) + + +@dataclass +class FakeApp: + """Fake app instance.""" + config: dict + db = None + loop: asyncio.BaseEventLoop = None + ratelimiter = None + state_manager = None + storage = None + dispatcher = None + presence = None + + +def init_parser(): + parser = argparse.ArgumentParser() + subparser = parser.add_subparsers(help='operations') + + migration(subparser) + + return parser + + +def main(config): + """Start the script""" + loop = asyncio.get_event_loop() + cfg = getattr(config, config.MODE) + app = FakeApp(cfg.__dict__) + + loop.run_until_complete(init_app_db(app)) + init_app_managers(app) + + # initialize argparser + parser = init_parser() + + try: + if len(argv) < 2: + parser.print_help() + return + + args = parser.parse_args() + loop.run_until_complete(args.func(app, args)) + except Exception: + log.exception('error while running command') + finally: + loop.run_until_complete(app.db.close()) diff --git a/nginx.conf b/nginx.conf index 9590d5c..d42d2df 100644 --- a/nginx.conf +++ b/nginx.conf @@ -5,13 +5,11 @@ server { location / { proxy_pass http://localhost:5000; } -} -# Main litecord websocket proxy. -server { - server_name websocket.somewhere; - - location / { + # if you don't want to keep the gateway + # domain as the main domain, you can + # keep a separate server block + location /ws { proxy_pass http://localhost:5001; # those options are required for websockets diff --git a/run.py b/run.py index 98f674d..c1af61f 100644 --- a/run.py +++ b/run.py @@ -9,9 +9,28 @@ from quart import Quart, g, jsonify, request from logbook import StreamHandler, Logger from logbook.compat import redirect_logging +# import the config set by instance owner import config -from litecord.blueprints import gateway, auth, users, guilds, channels, \ - webhooks, science, voice, invites, relationships, dms + +from litecord.blueprints import ( + gateway, auth, users, guilds, channels, webhooks, science, + voice, invites, relationships, dms +) + +# those blueprints are separated from the "main" ones +# for code readability if people want to dig through +# the codebase. +from litecord.blueprints.guild import ( + guild_roles, guild_members, guild_channels, guild_mod +) + +from litecord.blueprints.channel import ( + channel_messages, channel_reactions, channel_pins +) + +from litecord.ratelimits.handler import ratelimit_handler +from litecord.ratelimits.main import RatelimitManager + from litecord.gateway import websocket_handler from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager @@ -50,8 +69,18 @@ bps = { auth: '/auth', users: '/users', relationships: '/users', + guilds: '/guilds', + guild_roles: '/guilds', + guild_members: '/guilds', + guild_channels: '/guilds', + guild_mod: '/guilds', + channels: '/channels', + channel_messages: '/channels', + channel_reactions: '/channels', + channel_pins: '/channels', + webhooks: None, science: None, voice: '/voice', @@ -64,6 +93,11 @@ for bp, suffix in bps.items(): app.register_blueprint(bp, url_prefix=f'/api/v6{suffix}') +@app.before_request +async def app_before_request(): + await ratelimit_handler() + + @app.after_request async def app_after_request(resp): origin = request.headers.get('Origin', '*') @@ -80,19 +114,44 @@ async def app_after_request(resp): # resp.headers['Access-Control-Allow-Methods'] = '*' resp.headers['Access-Control-Allow-Methods'] = \ resp.headers.get('allow', '*') + return resp -@app.before_serving -async def app_before_serving(): - log.info('opening db') +@app.after_request +async def app_set_ratelimit_headers(resp): + """Set the specific ratelimit headers.""" + try: + bucket = request.bucket + + if bucket is None: + raise AttributeError() + + resp.headers['X-RateLimit-Limit'] = str(bucket.requests) + resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens) + resp.headers['X-RateLimit-Reset'] = str(bucket._window + bucket.second) + + resp.headers['X-RateLimit-Global'] = str(request.bucket_global).lower() + + # only add Retry-After if we actually hit a ratelimit + retry_after = request.retry_after + if request.retry_after: + resp.headers['Retry-After'] = str(retry_after) + except AttributeError: + pass + + return resp + + +async def init_app_db(app): + """Connect to databases""" app.db = await asyncpg.create_pool(**app.config['POSTGRES']) - g.app = app +def init_app_managers(app): + """Initialize singleton classes.""" app.loop = asyncio.get_event_loop() - g.loop = asyncio.get_event_loop() - + app.ratelimiter = RatelimitManager() app.state_manager = StateManager() app.storage = Storage(app.db) @@ -101,6 +160,17 @@ async def app_before_serving(): app.state_manager, app.dispatcher) app.storage.presence = app.presence + +@app.before_serving +async def app_before_serving(): + log.info('opening db') + await init_app_db(app) + + g.app = app + g.loop = asyncio.get_event_loop() + + init_app_managers(app) + # start the websocket, etc host, port = app.config['WS_HOST'], app.config['WS_PORT'] log.info(f'starting websocket at {host} {port}') @@ -108,8 +178,11 @@ async def app_before_serving(): async def _wrapper(ws, url): # We wrap the main websocket_handler # so we can pass quart's app object. + + # TODO: pass just the app object await websocket_handler((app.db, app.state_manager, app.storage, - app.loop, app.dispatcher, app.presence), + app.loop, app.dispatcher, app.presence, + app.ratelimiter), ws, url) ws_future = websockets.serve(_wrapper, host, port) @@ -119,6 +192,15 @@ async def app_before_serving(): @app.after_serving async def app_after_serving(): + """Shutdown tasks for the server.""" + + # first close all clients, then close db + tasks = app.state_manager.gen_close_tasks() + if tasks: + await asyncio.wait(tasks, loop=app.loop) + + app.state_manager.close() + log.info('closing db') await app.db.close() @@ -130,9 +212,13 @@ async def handle_litecord_err(err): except IndexError: ejson = {} + try: + ejson['code'] = err.error_code + except AttributeError: + pass + return jsonify({ 'error': True, - # 'code': err.code, 'status': err.status_code, 'message': err.message, **ejson diff --git a/schema.sql b/schema.sql index e43ba45..2d3d5c6 100644 --- a/schema.sql +++ b/schema.sql @@ -75,6 +75,9 @@ CREATE TABLE IF NOT EXISTS users ( phone varchar(60) DEFAULT '', password_hash text NOT NULL, + -- store the last time the user logged in via the gateway + last_session timestamp without time zone default (now() at time zone 'utc'), + PRIMARY KEY (id, username, discriminator) ); @@ -131,6 +134,10 @@ CREATE TABLE IF NOT EXISTS user_settings ( -- appearance message_display_compact bool DEFAULT false, + + -- for now we store status but don't + -- actively use it, since the official client + -- sends its own presence on IDENTIFY status text DEFAULT 'online' NOT NULL, theme text DEFAULT 'dark' NOT NULL, developer_mode bool DEFAULT true, @@ -328,7 +335,7 @@ CREATE TABLE IF NOT EXISTS channel_overwrites ( channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, -- target_type = 0 -> use target_user - -- target_type = 1 -> user target_role + -- target_type = 1 -> use target_role -- discord already has overwrite.type = 'role' | 'member' -- so this allows us to be more compliant with the API target_type integer default null, @@ -344,11 +351,15 @@ CREATE TABLE IF NOT EXISTS channel_overwrites ( -- they're bigints (64bits), discord, -- for now, only needs 53. allow bigint DEFAULT 0, - deny bigint DEFAULT 0, - - PRIMARY KEY (channel_id, target_role, target_user) + deny bigint DEFAULT 0 ); +-- columns in private keys can't have NULL values, +-- so instead we use a custom constraint with UNIQUE + +ALTER TABLE channel_overwrites ADD CONSTRAINT channel_overwrites_uniq + UNIQUE (channel_id, target_role, target_user); + CREATE TABLE IF NOT EXISTS features ( id serial PRIMARY KEY, @@ -479,11 +490,6 @@ CREATE TABLE IF NOT EXISTS bans ( ); -CREATE TABLE IF NOT EXISTS embeds ( - -- TODO: this table - id bigint PRIMARY KEY -); - CREATE TABLE IF NOT EXISTS messages ( id bigint PRIMARY KEY, channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, @@ -504,6 +510,8 @@ CREATE TABLE IF NOT EXISTS messages ( tts bool default false, mention_everyone bool default false, + embeds jsonb DEFAULT '[]', + nonce bigint default 0, message_type int NOT NULL @@ -515,22 +523,22 @@ CREATE TABLE IF NOT EXISTS message_attachments ( PRIMARY KEY (message_id, attachment) ); -CREATE TABLE IF NOT EXISTS message_embeds ( - message_id bigint REFERENCES messages (id) UNIQUE, - embed_id bigint REFERENCES embeds (id), - PRIMARY KEY (message_id, embed_id) -); - CREATE TABLE IF NOT EXISTS message_reactions ( message_id bigint REFERENCES messages (id), user_id bigint REFERENCES users (id), - -- since it can be a custom emote, or unicode emoji + react_ts timestamp without time zone default (now() at time zone 'utc'), + + -- emoji_type = 0 -> custom emoji + -- emoji_type = 1 -> unicode emoji + emoji_type int DEFAULT 0, emoji_id bigint REFERENCES guild_emoji (id), - emoji_text text NOT NULL, - PRIMARY KEY (message_id, user_id, emoji_id, emoji_text) + emoji_text text ); +ALTER TABLE message_reactions ADD CONSTRAINT message_reactions_main_uniq + UNIQUE (message_id, user_id, emoji_id, emoji_text); + CREATE TABLE IF NOT EXISTS channel_pins ( channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, message_id bigint REFERENCES messages (id) ON DELETE CASCADE,