From 0f7ffaf717d6feb264621d00512dc16169476355 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Fri, 2 Nov 2018 22:07:32 -0300 Subject: [PATCH] storage: add Storage.get_reactions This finishes basic reaction code (both inserting and putting a reaction). SQL for instances: ```sql DROP TABLE message_reactions; ``` Then rerun `schema.sql` - channel.reactions: fix partial_emoji - schema.sql: add message_reactions.react_ts and unique constraint instead of primary key --- litecord/blueprints/channel/messages.py | 8 +-- litecord/blueprints/channel/reactions.py | 17 +++--- litecord/storage.py | 75 ++++++++++++++++++++++-- schema.sql | 9 ++- 4 files changed, 90 insertions(+), 19 deletions(-) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 231f83d..1e54fdf 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -81,7 +81,7 @@ async def get_messages(channel_id): result = [] for message_id in message_ids: - msg = await app.storage.get_message(message_id['id']) + msg = await app.storage.get_message(message_id['id'], user_id) if msg is None: continue @@ -98,7 +98,7 @@ async def get_single_message(channel_id, message_id): await channel_check(user_id, channel_id) # TODO: check READ_MESSAGE_HISTORY permissions - message = await app.storage.get_message(message_id) + message = await app.storage.get_message(message_id, user_id) if not message: raise MessageNotFound() @@ -168,7 +168,7 @@ async def create_message(channel_id): MessageType.DEFAULT.value ) - payload = await app.storage.get_message(message_id) + payload = await app.storage.get_message(message_id, user_id) if ctype == ChannelType.DM: # guild id here is the peer's ID. @@ -218,7 +218,7 @@ async def edit_message(channel_id, message_id): # TODO: update embed - message = await app.storage.get_message(message_id) + message = await app.storage.get_message(message_id, user_id) # only dispatch MESSAGE_UPDATE if we actually had any update to start with if updated: diff --git a/litecord/blueprints/channel/reactions.py b/litecord/blueprints/channel/reactions.py index 2c01977..6db9e6b 100644 --- a/litecord/blueprints/channel/reactions.py +++ b/litecord/blueprints/channel/reactions.py @@ -47,10 +47,11 @@ def emoji_info_from_str(emoji: str) -> tuple: return emoji_type, emoji_id, emoji_name -def _partial_emoji(emoji_type, emoji_id, emoji_name) -> dict: +def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict: + print(emoji_type, emoji_id, emoji_name) return { - 'id': None if emoji_type.UNICODE else emoji_id, - 'name': emoji_id if emoji_type.UNICODE else emoji_name + 'id': None if emoji_type == EmojiType.UNICODE else emoji_id, + 'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id } @@ -88,7 +89,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str): emoji_id if emoji_type == EmojiType.UNICODE else None ) - partial = _partial_emoji(emoji_type, emoji_id, emoji_name) + partial = partial_emoji(emoji_type, emoji_id, emoji_name) payload = _make_payload(user_id, channel_id, message_id, partial) if ctype in GUILD_CHANS: @@ -100,7 +101,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str): return '', 204 -def _emoji_sql(emoji_type, emoji_id, emoji_name, param=4): +def emoji_sql(emoji_type, emoji_id, emoji_name, param=4): """Extract SQL clauses to search for specific emoji in the message_reactions table.""" param = f'${param}' @@ -120,7 +121,7 @@ def _emoji_sql_simple(emoji: str, param=4): """Simpler version of _emoji_sql for functions that don't need the results from emoji_info_from_str.""" emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) - return _emoji_sql(emoji_type, emoji_id, emoji_name, param) + return emoji_sql(emoji_type, emoji_id, emoji_name, param) async def remove_reaction(channel_id: int, message_id: int, @@ -128,7 +129,7 @@ async def remove_reaction(channel_id: int, message_id: int, ctype, guild_id = await channel_check(user_id, channel_id) emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji) - where_ext, main_emoji = _emoji_sql(emoji_type, emoji_id, emoji_name) + where_ext, main_emoji = emoji_sql(emoji_type, emoji_id, emoji_name) await app.db.execute( f""" @@ -139,7 +140,7 @@ async def remove_reaction(channel_id: int, message_id: int, {where_ext} """, message_id, user_id, emoji_type, main_emoji) - partial = _partial_emoji(emoji_type, emoji_id, emoji_name) + partial = partial_emoji(emoji_type, emoji_id, emoji_name) payload = _make_payload(user_id, channel_id, message_id, partial) if ctype in GUILD_CHANS: diff --git a/litecord/storage.py b/litecord/storage.py index 2de2b99..313b4f2 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -5,6 +5,9 @@ from logbook import Logger from .enums import ChannelType, RelationshipType from .schemas import USER_MENTION, ROLE_MENTION +from litecord.blueprints.channel.reactions import ( + emoji_info_from_str, EmojiType, emoji_sql, partial_emoji +) log = Logger(__name__) @@ -553,7 +556,72 @@ class Storage: return res - async def get_message(self, message_id: int) -> Dict: + async def get_reactions(self, message_id: int, user_id=None) -> List: + """Get all reactions in a message.""" + reactions = await self.db.fetch(""" + SELECT user_id, emoji_type, emoji_id, emoji_text + FROM message_reactions + ORDER BY react_ts + """) + + # ordered list of emoji + emoji = [] + + # the current state of emoji info + react_stats = {} + + # to generate the list, we pass through all + # all reactions and insert them all. + + # we can't use a set() because that + # doesn't guarantee any order. + for row in reactions: + etype = EmojiType(row['emoji_type']) + eid, etext = row['emoji_id'], row['emoji_text'] + + # get the main key to use, given + # the emoji information + _, main_emoji = emoji_sql(etype, eid, etext) + + if main_emoji in emoji: + continue + + # maintain order (first reacted comes first + # on the reaction list) + emoji.append(main_emoji) + + react_stats[main_emoji] = { + 'count': 0, + 'me': False, + 'emoji': partial_emoji(etype, eid, etext) + } + + # then the 2nd pass, where we insert + # the info for each reaction in the react_stats + # dictionary + for row in reactions: + etype = EmojiType(row['emoji_type']) + eid, etext = row['emoji_id'], row['emoji_text'] + + # same thing as the last loop, + # extracting main key + _, main_emoji = emoji_sql(etype, eid, etext) + + stats = react_stats[main_emoji] + stats['count'] += 1 + + print(row['user_id'], user_id) + if row['user_id'] == user_id: + stats['me'] = True + + # after processing reaction counts, + # we get them in the same order + # they were defined in the first loop. + print(emoji) + print(react_stats) + return list(map(react_stats.get, emoji)) + + async def get_message(self, message_id: int, user_id=None) -> Dict: """Get a single message's payload.""" row = await self.db.fetchrow(""" SELECT id::text, channel_id::text, author_id, content, @@ -614,6 +682,8 @@ class Storage: res['mention_roles'] = await self._msg_regex( ROLE_MENTION, _get_role_mention, content) + res['reactions'] = await self.get_reactions(message_id, user_id) + # TODO: handle webhook authors res['author'] = await self.get_user(res['author_id']) res.pop('author_id') @@ -624,9 +694,6 @@ class Storage: # TODO: res['embeds'] res['embeds'] = [] - # TODO: res['reactions'] - res['reactions'] = [] - # TODO: res['pinned'] res['pinned'] = False diff --git a/schema.sql b/schema.sql index e8d101c..3a654c3 100644 --- a/schema.sql +++ b/schema.sql @@ -528,15 +528,18 @@ CREATE TABLE IF NOT EXISTS message_reactions ( message_id bigint REFERENCES messages (id), user_id bigint REFERENCES users (id), + react_ts timestamp without time zone default (now() at time zone 'utc'), + -- emoji_type = 0 -> custom emoji -- emoji_type = 1 -> unicode emoji emoji_type int DEFAULT 0, emoji_id bigint REFERENCES guild_emoji (id), - emoji_text text, - - PRIMARY KEY (message_id, user_id, emoji_id, emoji_text) + emoji_text text ); +ALTER TABLE message_reactions ADD CONSTRAINT message_reactions_main_uniq + UNIQUE (message_id, user_id, emoji_id, emoji_text); + CREATE TABLE IF NOT EXISTS channel_pins ( channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, message_id bigint REFERENCES messages (id) ON DELETE CASCADE,