From 2b1f9489b73b023739ba0e5a0337e83c1d0cbb7f Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 2 Nov 2018 18:29:07 -0300 Subject: [PATCH] channel: add reactions blueprint SQL for instances: ```sql DROP TABLE message_reactions; ``` Then rerun `schema.sql`. --- litecord/blueprints/channel/messages.py | 20 +- litecord/blueprints/channel/reactions.py | 224 +++++++++++++++++++++++ nginx.conf | 10 +- schema.sql | 7 +- 4 files changed, 246 insertions(+), 15 deletions(-) create mode 100644 litecord/blueprints/channel/reactions.py diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 4663693..7ea70fb 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -16,6 +16,18 @@ log = Logger(__name__) bp = Blueprint('channel_messages', __name__) +def extract_limit(request, default: int = 50): + try: + limit = int(request.args.get('limit', 50)) + + if limit not in range(0, 100): + raise ValueError() + except (TypeError, ValueError): + raise BadRequest('limit not int') + + return limit + + def query_tuple_from_args(args: dict, limit: int) -> tuple: before, after = None, None @@ -41,13 +53,7 @@ async def get_messages(channel_id): # TODO: check READ_MESSAGE_HISTORY permission await channel_check(user_id, channel_id) - try: - limit = int(request.args.get('limit', 50)) - - if limit not in range(0, 100): - raise ValueError() - except (TypeError, ValueError): - raise BadRequest('limit not int') + limit = extract_limit(request, 50) where_clause = '' before, after = query_tuple_from_args(request.args, limit) diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py new file mode 100644 index 0000000..2c01977 --- /dev/null +++ b/litecord/blueprints/channel/reactions.py @@ -0,0 +1,224 @@ +from enum import IntEnum + +from quart import Blueprint, request, current_app as app, jsonify +from logbook import Logger + + +from litecord.utils import async_map +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import channel_check +from litecord.blueprints.channel.messages import ( + query_tuple_from_args, extract_limit +) + +from litecord.errors import MessageNotFound, Forbidden, BadRequest +from litecord.enums import GUILD_CHANS + + +log = Logger(__name__) +bp = Blueprint('channel_reactions', __name__) + +BASEPATH = '//messages//reactions' + + +class EmojiType(IntEnum): + CUSTOM = 0 + UNICODE = 1 + + +def emoji_info_from_str(emoji: str) -> tuple: + """Extract emoji information from an emoji string + given on the reaction endpoints.""" + # custom emoji have an emoji of name:id + # unicode emoji just have the raw unicode. + + # try checking if the emoji is custom or unicode + emoji_type = 0 if ':' in emoji else 1 + emoji_type = EmojiType(emoji_type) + + # extract the emoji id OR the unicode value of the emoji + # depending if it is custom or not + emoji_id = (int(emoji.split(':')[1]) + if emoji_type == EmojiType.CUSTOM + else emoji) + + emoji_name = emoji.split(':')[0] + + return emoji_type, emoji_id, emoji_name + + +def _partial_emoji(emoji_type, emoji_id, emoji_name) -> dict: + return { + 'id': None if emoji_type.UNICODE else emoji_id, + 'name': emoji_id if emoji_type.UNICODE else emoji_name + } + + +def _make_payload(user_id, channel_id, message_id, partial): + return { + 'user_id': str(user_id), + 'channel_id': str(channel_id), + 'message_id': str(message_id), + 'emoji': partial + } + + +@bp.route(f'{BASEPATH}//@me', methods=['PUT']) +async def add_reaction(channel_id: int, message_id: int, emoji: str): + """Put a reaction.""" + user_id = await token_check() + + # TODO: check READ_MESSAGE_HISTORY permission + # and ADD_REACTIONS. look on route docs. + ctype, guild_id = await channel_check(user_id, channel_id) + + emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) + + await app.db.execute( + """ + INSERT INTO message_reactions (message_id, user_id, + emoji_type, emoji_id, emoji_text) + VALUES ($1, $2, $3, $4, $5) + """, message_id, user_id, emoji_type, + + # if it is custom, we put the emoji_id on emoji_id + # column, if it isn't, we put it on emoji_text + # column. + emoji_id if emoji_type == EmojiType.CUSTOM else None, + emoji_id if emoji_type == EmojiType.UNICODE else None + ) + + partial = _partial_emoji(emoji_type, emoji_id, emoji_name) + payload = _make_payload(user_id, channel_id, message_id, partial) + + if ctype in GUILD_CHANS: + payload['guild_id'] = str(guild_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'MESSAGE_REACTION_ADD', payload) + + return '', 204 + + +def _emoji_sql(emoji_type, emoji_id, emoji_name, param=4): + """Extract SQL clauses to search for specific emoji + in the message_reactions table.""" + param = f'${param}' + + # know which column to filter with + where_ext = (f'AND emoji_id = {param}' + if emoji_type == EmojiType.CUSTOM else + f'AND emoji_text = {param}') + + # which emoji to remove (custom or unicode) + main_emoji = emoji_id if emoji_type == EmojiType.CUSTOM else emoji_name + + return where_ext, main_emoji + + +def _emoji_sql_simple(emoji: str, param=4): + """Simpler version of _emoji_sql for functions that + don't need the results from emoji_info_from_str.""" + emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) + return _emoji_sql(emoji_type, emoji_id, emoji_name, param) + + +async def remove_reaction(channel_id: int, message_id: int, + user_id: int, emoji: str): + ctype, guild_id = await channel_check(user_id, channel_id) + + emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) + where_ext, main_emoji = _emoji_sql(emoji_type, emoji_id, emoji_name) + + await app.db.execute( + f""" + DELETE FROM message_reactions + WHERE message_id = $1 + AND user_id = $2 + AND emoji_type = $3 + {where_ext} + """, message_id, user_id, emoji_type, main_emoji) + + partial = _partial_emoji(emoji_type, emoji_id, emoji_name) + payload = _make_payload(user_id, channel_id, message_id, partial) + + if ctype in GUILD_CHANS: + payload['guild_id'] = str(guild_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'MESSAGE_REACTION_REMOVE', payload) + + +@bp.route(f'{BASEPATH}//@me', methods=['DELETE']) +async def remove_own_reaction(channel_id, message_id, emoji): + """Remove a reaction.""" + user_id = await token_check() + + await remove_reaction(channel_id, message_id, user_id, emoji) + + return '', 204 + + +@bp.route(f'{BASEPATH}//', methods=['DELETE']) +async def remove_user_reaction(channel_id, message_id, emoji, other_id): + """Remove a reaction made by another user.""" + await token_check() + + # TODO: check MANAGE_MESSAGES permission (and use user_id + # from token_check to do it) + await remove_reaction(channel_id, message_id, other_id, emoji) + + return '', 204 + + +@bp.route(f'{BASEPATH}/', methods=['GET']) +async def list_users_reaction(channel_id, message_id, emoji): + """Get the list of all users who reacted with a certain emoji.""" + user_id = await token_check() + + # this is not using either ctype or guild_id + # that are returned by channel_check + await channel_check(user_id, channel_id) + + limit = extract_limit(request, 25) + before, after = query_tuple_from_args(request.args, limit) + + before_clause = 'AND user_id < $2' if before else '' + after_clause = 'AND user_id > $3' if after else '' + + where_ext, main_emoji = _emoji_sql_simple(emoji, 4) + + rows = await app.db.fetch(f""" + SELECT user_id + FROM message_reactions + WHERE message_id = $1 {before_clause} {after_clause} {where_ext} + """, message_id, before, after, main_emoji) + + user_ids = [r['user_id'] for r in rows] + users = await async_map(app.storage.get_user, user_ids) + return jsonify(users) + + +@bp.route(f'{BASEPATH}', methods=['DELETE']) +async def remove_all_reactions(channel_id, message_id): + """Remove all reactions in a message.""" + user_id = await token_check() + + # TODO: check MANAGE_MESSAGES permission + ctype, guild_id = await channel_check(user_id, channel_id) + + await app.db.execute(""" + DELETE FROM message_reactions + WHERE message_id = $1 + """, message_id) + + payload = { + 'channel_id': str(channel_id), + 'message_id': str(message_id), + } + + if ctype in GUILD_CHANS: + payload['guild_id'] = str(guild_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'MESSAGE_REACTION_REMOVE_ALL', payload) diff --git a/nginx.conf b/nginx.conf index 9590d5c..d42d2df 100644 --- a/nginx.conf +++ b/nginx.conf @@ -5,13 +5,11 @@ server { location / { proxy_pass http://localhost:5000; } -} -# Main litecord websocket proxy. -server { - server_name websocket.somewhere; - - location / { + # if you don't want to keep the gateway + # domain as the main domain, you can + # keep a separate server block + location /ws { proxy_pass http://localhost:5001; # those options are required for websockets diff --git a/schema.sql b/schema.sql index b520d18..e8d101c 100644 --- a/schema.sql +++ b/schema.sql @@ -528,9 +528,12 @@ CREATE TABLE IF NOT EXISTS message_reactions ( message_id bigint REFERENCES messages (id), user_id bigint REFERENCES users (id), - -- since it can be a custom emote, or unicode emoji + -- emoji_type = 0 -> custom emoji + -- emoji_type = 1 -> unicode emoji + emoji_type int DEFAULT 0, emoji_id bigint REFERENCES guild_emoji (id), - emoji_text text NOT NULL, + emoji_text text, + PRIMARY KEY (message_id, user_id, emoji_id, emoji_text) );