From 94fe51ac69c38da665b414d93f1a478a71239f17 Mon Sep 17 00:00:00 2001 From: Luna Date: Tue, 4 Dec 2018 18:37:42 -0300 Subject: [PATCH] channel.messages: add embed insertion - storage: move pg_set_json to litecord.utils to fix circular imports - storage: add embed fetch to get_message - embed.schemas: fix author.url's url - schemas: add EMBED_OBJECT to MESSAGE_CREATE --- litecord/blueprints/channel/messages.py | 53 ++++++++++++++----------- litecord/embed/schemas.py | 8 +--- litecord/schemas.py | 7 ++++ litecord/storage.py | 30 +++----------- litecord/utils.py | 19 +++++++++ 5 files changed, 63 insertions(+), 54 deletions(-) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 4bd0206..cf5d552 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -10,6 +10,7 @@ from litecord.errors import MessageNotFound, Forbidden, BadRequest from litecord.enums import MessageType, ChannelType, GUILD_CHANS from litecord.snowflake import get_snowflake from litecord.schemas import validate, MESSAGE_CREATE +from litecord.utils import pg_set_json log = Logger(__name__) @@ -143,24 +144,29 @@ async def create_message(channel_id: int, actual_guild_id: int, author_id: int, data: dict) -> int: message_id = get_snowflake() - await app.db.execute( - """ - INSERT INTO messages (id, channel_id, guild_id, author_id, - content, tts, mention_everyone, nonce, message_type) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - """, - message_id, - channel_id, - actual_guild_id, - author_id, - data['content'], + async with app.db.acquire() as conn: + await pg_set_json(conn) - data['tts'], - data['everyone_mention'], + await conn.execute( + """ + INSERT INTO messages (id, channel_id, guild_id, author_id, + content, tts, mention_everyone, nonce, message_type, + embeds) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + """, + message_id, + channel_id, + actual_guild_id, + author_id, + data['content'], - data['nonce'], - MessageType.DEFAULT.value - ) + data['tts'], + data['everyone_mention'], + + data['nonce'], + MessageType.DEFAULT.value, + data.get('embeds', []) + ) return message_id @@ -231,7 +237,6 @@ async def _create_message(channel_id): actual_guild_id = guild_id j = validate(await request.get_json(), MESSAGE_CREATE) - message_id = get_snowflake() # TODO: check connection to the gateway @@ -247,12 +252,14 @@ async def _create_message(channel_id): user_id, channel_id, 'send_tts_messages', False )) - await create_message(channel_id, actual_guild_id, user_id, { - 'content': j['content'], - 'tts': is_tts, - 'nonce': int(j.get('nonce', 0)), - 'everyone_mention': mentions_everyone or mentions_here, - }) + message_id = await create_message( + channel_id, actual_guild_id, user_id, { + 'content': j['content'], + 'tts': is_tts, + 'nonce': int(j.get('nonce', 0)), + 'everyone_mention': mentions_everyone or mentions_here, + 'embeds': [j['embed']] if 'embed' in j else [], + }) payload = await app.storage.get_message(message_id, user_id) diff --git a/litecord/embed/schemas.py b/litecord/embed/schemas.py index 1d80867..cace161 100644 --- a/litecord/embed/schemas.py +++ b/litecord/embed/schemas.py @@ -2,7 +2,6 @@ litecord.embed.schemas - embed input validators. """ import urllib.parse - from litecord.types import Color @@ -46,7 +45,7 @@ EMBED_AUTHOR = { 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False }, 'url': { - 'type': EmbedURL, 'required': False, + 'coerce': EmbedURL, 'required': False, }, 'icon_url': { 'coerce': EmbedURL, 'required': False, @@ -56,10 +55,7 @@ EMBED_AUTHOR = { EMBED_OBJECT = { 'title': { 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False}, - 'type': { - 'type': 'string', 'allowed': ['rich'], 'required': False, - 'default': 'rich' - }, + # NOTE: type set by us 'description': { 'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': False, }, diff --git a/litecord/schemas.py b/litecord/schemas.py index 42915c8..1240404 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -12,6 +12,7 @@ from .enums import ( MessageNotifications, ChannelType, VerificationLevel ) +from litecord.embed.schemas import EMBED_OBJECT log = Logger(__name__) @@ -372,6 +373,12 @@ MESSAGE_CREATE = { 'nonce': {'type': 'snowflake', 'required': False}, 'tts': {'type': 'boolean', 'required': False}, + 'embed': { + 'type': 'dict', + 'schema': EMBED_OBJECT, + 'required': False + } + # TODO: file, embed, payload_json } diff --git a/litecord/storage.py b/litecord/storage.py index 29cf405..f4cc1e5 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -12,6 +12,7 @@ from litecord.blueprints.channel.reactions import ( from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE from litecord.types import timestamp_ +from litecord.utils import pg_set_json log = Logger(__name__) @@ -33,24 +34,6 @@ def str_(val): return maybe(str, val) -async def _set_json(con): - """Set JSON and JSONB codecs for an - asyncpg connection.""" - await con.set_type_codec( - 'json', - encoder=json.dumps, - decoder=json.loads, - schema='pg_catalog' - ) - - await con.set_type_codec( - 'jsonb', - encoder=json.dumps, - decoder=json.loads, - schema='pg_catalog' - ) - - def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int): """Filter recipients in a list of recipients, removing the one that is reundant (ourselves).""" @@ -73,13 +56,13 @@ class Storage: # set_type_codec, so we must set it manually # by acquiring the connection async with self.db.acquire() as con: - await _set_json(con) + await pg_set_json(con) return await con.fetchrow(query, *args) async def fetch_with_json(self, query: str, *args): """Fetch many rows with JSON/JSONB support.""" async with self.db.acquire() as con: - await _set_json(con) + await pg_set_json(con) return await con.fetch(query, *args) async def get_user(self, user_id, secure=False) -> Dict[str, Any]: @@ -631,10 +614,10 @@ class Storage: async def get_message(self, message_id: int, user_id=None) -> Dict: """Get a single message's payload.""" - row = await self.db.fetchrow(""" + row = await self.fetchrow_with_json(""" SELECT id::text, channel_id::text, author_id, content, created_at AS timestamp, edited_at AS edited_timestamp, - tts, mention_everyone, nonce, message_type + tts, mention_everyone, nonce, message_type, embeds FROM messages WHERE id = $1 """, message_id) @@ -698,9 +681,6 @@ class Storage: # TODO: res['attachments'] res['attachments'] = [] - # TODO: res['embeds'] - res['embeds'] = [] - # TODO: res['member'] for partial member data # of the author diff --git a/litecord/utils.py b/litecord/utils.py index 26c6ba2..330e67e 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -1,4 +1,5 @@ import asyncio +import json from logbook import Logger log = Logger(__name__) @@ -110,3 +111,21 @@ def mmh3(key: str, seed: int = 0): h1 ^= _u(h1) >> 16 return _u(h1) >> 0 + + +async def pg_set_json(con): + """Set JSON and JSONB codecs for an + asyncpg connection.""" + await con.set_type_codec( + 'json', + encoder=json.dumps, + decoder=json.loads, + schema='pg_catalog' + ) + + await con.set_type_codec( + 'jsonb', + encoder=json.dumps, + decoder=json.loads, + schema='pg_catalog' + )