storage: add Storage.get_reactions

This finishes basic reaction code (both inserting and putting a reaction).

SQL for instances:
```sql
DROP TABLE message_reactions;
```

Then rerun `schema.sql`

 - channel.reactions: fix partial_emoji
 - schema.sql: add message_reactions.react_ts and unique constraint
    instead of primary key
This commit is contained in:
Luna Mendes 2018-11-02 22:07:32 -03:00
parent db7fbdb954
commit 0f7ffaf717
4 changed files with 90 additions and 19 deletions

View File

@ -81,7 +81,7 @@ async def get_messages(channel_id):
result = []
for message_id in message_ids:
msg = await app.storage.get_message(message_id['id'])
msg = await app.storage.get_message(message_id['id'], user_id)
if msg is None:
continue
@ -98,7 +98,7 @@ async def get_single_message(channel_id, message_id):
await channel_check(user_id, channel_id)
# TODO: check READ_MESSAGE_HISTORY permissions
message = await app.storage.get_message(message_id)
message = await app.storage.get_message(message_id, user_id)
if not message:
raise MessageNotFound()
@ -168,7 +168,7 @@ async def create_message(channel_id):
MessageType.DEFAULT.value
)
payload = await app.storage.get_message(message_id)
payload = await app.storage.get_message(message_id, user_id)
if ctype == ChannelType.DM:
# guild id here is the peer's ID.
@ -218,7 +218,7 @@ async def edit_message(channel_id, message_id):
# TODO: update embed
message = await app.storage.get_message(message_id)
message = await app.storage.get_message(message_id, user_id)
# only dispatch MESSAGE_UPDATE if we actually had any update to start with
if updated:

View File

@ -47,10 +47,11 @@ def emoji_info_from_str(emoji: str) -> tuple:
return emoji_type, emoji_id, emoji_name
def _partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
def partial_emoji(emoji_type, emoji_id, emoji_name) -> dict:
print(emoji_type, emoji_id, emoji_name)
return {
'id': None if emoji_type.UNICODE else emoji_id,
'name': emoji_id if emoji_type.UNICODE else emoji_name
'id': None if emoji_type == EmojiType.UNICODE else emoji_id,
'name': emoji_name if emoji_type == EmojiType.UNICODE else emoji_id
}
@ -88,7 +89,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str):
emoji_id if emoji_type == EmojiType.UNICODE else None
)
partial = _partial_emoji(emoji_type, emoji_id, emoji_name)
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
payload = _make_payload(user_id, channel_id, message_id, partial)
if ctype in GUILD_CHANS:
@ -100,7 +101,7 @@ async def add_reaction(channel_id: int, message_id: int, emoji: str):
return '', 204
def _emoji_sql(emoji_type, emoji_id, emoji_name, param=4):
def emoji_sql(emoji_type, emoji_id, emoji_name, param=4):
"""Extract SQL clauses to search for specific emoji
in the message_reactions table."""
param = f'${param}'
@ -120,7 +121,7 @@ def _emoji_sql_simple(emoji: str, param=4):
"""Simpler version of _emoji_sql for functions that
don't need the results from emoji_info_from_str."""
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
return _emoji_sql(emoji_type, emoji_id, emoji_name, param)
return emoji_sql(emoji_type, emoji_id, emoji_name, param)
async def remove_reaction(channel_id: int, message_id: int,
@ -128,7 +129,7 @@ async def remove_reaction(channel_id: int, message_id: int,
ctype, guild_id = await channel_check(user_id, channel_id)
emoji_type, emoji_id, emoji_name = emoji_info_from_str(emoji)
where_ext, main_emoji = _emoji_sql(emoji_type, emoji_id, emoji_name)
where_ext, main_emoji = emoji_sql(emoji_type, emoji_id, emoji_name)
await app.db.execute(
f"""
@ -139,7 +140,7 @@ async def remove_reaction(channel_id: int, message_id: int,
{where_ext}
""", message_id, user_id, emoji_type, main_emoji)
partial = _partial_emoji(emoji_type, emoji_id, emoji_name)
partial = partial_emoji(emoji_type, emoji_id, emoji_name)
payload = _make_payload(user_id, channel_id, message_id, partial)
if ctype in GUILD_CHANS:

View File

@ -5,6 +5,9 @@ from logbook import Logger
from .enums import ChannelType, RelationshipType
from .schemas import USER_MENTION, ROLE_MENTION
from litecord.blueprints.channel.reactions import (
emoji_info_from_str, EmojiType, emoji_sql, partial_emoji
)
log = Logger(__name__)
@ -553,7 +556,72 @@ class Storage:
return res
async def get_message(self, message_id: int) -> Dict:
async def get_reactions(self, message_id: int, user_id=None) -> List:
"""Get all reactions in a message."""
reactions = await self.db.fetch("""
SELECT user_id, emoji_type, emoji_id, emoji_text
FROM message_reactions
ORDER BY react_ts
""")
# ordered list of emoji
emoji = []
# the current state of emoji info
react_stats = {}
# to generate the list, we pass through all
# all reactions and insert them all.
# we can't use a set() because that
# doesn't guarantee any order.
for row in reactions:
etype = EmojiType(row['emoji_type'])
eid, etext = row['emoji_id'], row['emoji_text']
# get the main key to use, given
# the emoji information
_, main_emoji = emoji_sql(etype, eid, etext)
if main_emoji in emoji:
continue
# maintain order (first reacted comes first
# on the reaction list)
emoji.append(main_emoji)
react_stats[main_emoji] = {
'count': 0,
'me': False,
'emoji': partial_emoji(etype, eid, etext)
}
# then the 2nd pass, where we insert
# the info for each reaction in the react_stats
# dictionary
for row in reactions:
etype = EmojiType(row['emoji_type'])
eid, etext = row['emoji_id'], row['emoji_text']
# same thing as the last loop,
# extracting main key
_, main_emoji = emoji_sql(etype, eid, etext)
stats = react_stats[main_emoji]
stats['count'] += 1
print(row['user_id'], user_id)
if row['user_id'] == user_id:
stats['me'] = True
# after processing reaction counts,
# we get them in the same order
# they were defined in the first loop.
print(emoji)
print(react_stats)
return list(map(react_stats.get, emoji))
async def get_message(self, message_id: int, user_id=None) -> Dict:
"""Get a single message's payload."""
row = await self.db.fetchrow("""
SELECT id::text, channel_id::text, author_id, content,
@ -614,6 +682,8 @@ class Storage:
res['mention_roles'] = await self._msg_regex(
ROLE_MENTION, _get_role_mention, content)
res['reactions'] = await self.get_reactions(message_id, user_id)
# TODO: handle webhook authors
res['author'] = await self.get_user(res['author_id'])
res.pop('author_id')
@ -624,9 +694,6 @@ class Storage:
# TODO: res['embeds']
res['embeds'] = []
# TODO: res['reactions']
res['reactions'] = []
# TODO: res['pinned']
res['pinned'] = False

View File

@ -528,15 +528,18 @@ CREATE TABLE IF NOT EXISTS message_reactions (
message_id bigint REFERENCES messages (id),
user_id bigint REFERENCES users (id),
react_ts timestamp without time zone default (now() at time zone 'utc'),
-- emoji_type = 0 -> custom emoji
-- emoji_type = 1 -> unicode emoji
emoji_type int DEFAULT 0,
emoji_id bigint REFERENCES guild_emoji (id),
emoji_text text,
PRIMARY KEY (message_id, user_id, emoji_id, emoji_text)
emoji_text text
);
ALTER TABLE message_reactions ADD CONSTRAINT message_reactions_main_uniq
UNIQUE (message_id, user_id, emoji_id, emoji_text);
CREATE TABLE IF NOT EXISTS channel_pins (
channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
message_id bigint REFERENCES messages (id) ON DELETE CASCADE,