channel.messages: add embed insertion

- storage: move pg_set_json to litecord.utils to fix circular imports
 - storage: add embed fetch to get_message
 - embed.schemas: fix author.url's url
 - schemas: add EMBED_OBJECT to MESSAGE_CREATE
This commit is contained in:
Luna 2018-12-04 18:37:42 -03:00
parent 8b97195404
commit 94fe51ac69
5 changed files with 63 additions and 54 deletions

View File

@ -10,6 +10,7 @@ from litecord.errors import MessageNotFound, Forbidden, BadRequest
from litecord.enums import MessageType, ChannelType, GUILD_CHANS from litecord.enums import MessageType, ChannelType, GUILD_CHANS
from litecord.snowflake import get_snowflake from litecord.snowflake import get_snowflake
from litecord.schemas import validate, MESSAGE_CREATE from litecord.schemas import validate, MESSAGE_CREATE
from litecord.utils import pg_set_json
log = Logger(__name__) log = Logger(__name__)
@ -143,11 +144,15 @@ async def create_message(channel_id: int, actual_guild_id: int,
author_id: int, data: dict) -> int: author_id: int, data: dict) -> int:
message_id = get_snowflake() message_id = get_snowflake()
await app.db.execute( async with app.db.acquire() as conn:
await pg_set_json(conn)
await conn.execute(
""" """
INSERT INTO messages (id, channel_id, guild_id, author_id, INSERT INTO messages (id, channel_id, guild_id, author_id,
content, tts, mention_everyone, nonce, message_type) content, tts, mention_everyone, nonce, message_type,
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) embeds)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
""", """,
message_id, message_id,
channel_id, channel_id,
@ -159,7 +164,8 @@ async def create_message(channel_id: int, actual_guild_id: int,
data['everyone_mention'], data['everyone_mention'],
data['nonce'], data['nonce'],
MessageType.DEFAULT.value MessageType.DEFAULT.value,
data.get('embeds', [])
) )
return message_id return message_id
@ -231,7 +237,6 @@ async def _create_message(channel_id):
actual_guild_id = guild_id actual_guild_id = guild_id
j = validate(await request.get_json(), MESSAGE_CREATE) j = validate(await request.get_json(), MESSAGE_CREATE)
message_id = get_snowflake()
# TODO: check connection to the gateway # TODO: check connection to the gateway
@ -247,11 +252,13 @@ async def _create_message(channel_id):
user_id, channel_id, 'send_tts_messages', False user_id, channel_id, 'send_tts_messages', False
)) ))
await create_message(channel_id, actual_guild_id, user_id, { message_id = await create_message(
channel_id, actual_guild_id, user_id, {
'content': j['content'], 'content': j['content'],
'tts': is_tts, 'tts': is_tts,
'nonce': int(j.get('nonce', 0)), 'nonce': int(j.get('nonce', 0)),
'everyone_mention': mentions_everyone or mentions_here, 'everyone_mention': mentions_everyone or mentions_here,
'embeds': [j['embed']] if 'embed' in j else [],
}) })
payload = await app.storage.get_message(message_id, user_id) payload = await app.storage.get_message(message_id, user_id)

View File

@ -2,7 +2,6 @@
litecord.embed.schemas - embed input validators. litecord.embed.schemas - embed input validators.
""" """
import urllib.parse import urllib.parse
from litecord.types import Color from litecord.types import Color
@ -46,7 +45,7 @@ EMBED_AUTHOR = {
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False
}, },
'url': { 'url': {
'type': EmbedURL, 'required': False, 'coerce': EmbedURL, 'required': False,
}, },
'icon_url': { 'icon_url': {
'coerce': EmbedURL, 'required': False, 'coerce': EmbedURL, 'required': False,
@ -56,10 +55,7 @@ EMBED_AUTHOR = {
EMBED_OBJECT = { EMBED_OBJECT = {
'title': { 'title': {
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False}, 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False},
'type': { # NOTE: type set by us
'type': 'string', 'allowed': ['rich'], 'required': False,
'default': 'rich'
},
'description': { 'description': {
'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': False, 'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': False,
}, },

View File

@ -12,6 +12,7 @@ from .enums import (
MessageNotifications, ChannelType, VerificationLevel MessageNotifications, ChannelType, VerificationLevel
) )
from litecord.embed.schemas import EMBED_OBJECT
log = Logger(__name__) log = Logger(__name__)
@ -372,6 +373,12 @@ MESSAGE_CREATE = {
'nonce': {'type': 'snowflake', 'required': False}, 'nonce': {'type': 'snowflake', 'required': False},
'tts': {'type': 'boolean', 'required': False}, 'tts': {'type': 'boolean', 'required': False},
'embed': {
'type': 'dict',
'schema': EMBED_OBJECT,
'required': False
}
# TODO: file, embed, payload_json # TODO: file, embed, payload_json
} }

View File

@ -12,6 +12,7 @@ from litecord.blueprints.channel.reactions import (
from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
from litecord.types import timestamp_ from litecord.types import timestamp_
from litecord.utils import pg_set_json
log = Logger(__name__) log = Logger(__name__)
@ -33,24 +34,6 @@ def str_(val):
return maybe(str, val) return maybe(str, val)
async def _set_json(con):
"""Set JSON and JSONB codecs for an
asyncpg connection."""
await con.set_type_codec(
'json',
encoder=json.dumps,
decoder=json.loads,
schema='pg_catalog'
)
await con.set_type_codec(
'jsonb',
encoder=json.dumps,
decoder=json.loads,
schema='pg_catalog'
)
def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int): def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int):
"""Filter recipients in a list of recipients, removing """Filter recipients in a list of recipients, removing
the one that is reundant (ourselves).""" the one that is reundant (ourselves)."""
@ -73,13 +56,13 @@ class Storage:
# set_type_codec, so we must set it manually # set_type_codec, so we must set it manually
# by acquiring the connection # by acquiring the connection
async with self.db.acquire() as con: async with self.db.acquire() as con:
await _set_json(con) await pg_set_json(con)
return await con.fetchrow(query, *args) return await con.fetchrow(query, *args)
async def fetch_with_json(self, query: str, *args): async def fetch_with_json(self, query: str, *args):
"""Fetch many rows with JSON/JSONB support.""" """Fetch many rows with JSON/JSONB support."""
async with self.db.acquire() as con: async with self.db.acquire() as con:
await _set_json(con) await pg_set_json(con)
return await con.fetch(query, *args) return await con.fetch(query, *args)
async def get_user(self, user_id, secure=False) -> Dict[str, Any]: async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
@ -631,10 +614,10 @@ class Storage:
async def get_message(self, message_id: int, user_id=None) -> Dict: async def get_message(self, message_id: int, user_id=None) -> Dict:
"""Get a single message's payload.""" """Get a single message's payload."""
row = await self.db.fetchrow(""" row = await self.fetchrow_with_json("""
SELECT id::text, channel_id::text, author_id, content, SELECT id::text, channel_id::text, author_id, content,
created_at AS timestamp, edited_at AS edited_timestamp, created_at AS timestamp, edited_at AS edited_timestamp,
tts, mention_everyone, nonce, message_type tts, mention_everyone, nonce, message_type, embeds
FROM messages FROM messages
WHERE id = $1 WHERE id = $1
""", message_id) """, message_id)
@ -698,9 +681,6 @@ class Storage:
# TODO: res['attachments'] # TODO: res['attachments']
res['attachments'] = [] res['attachments'] = []
# TODO: res['embeds']
res['embeds'] = []
# TODO: res['member'] for partial member data # TODO: res['member'] for partial member data
# of the author # of the author

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import json
from logbook import Logger from logbook import Logger
log = Logger(__name__) log = Logger(__name__)
@ -110,3 +111,21 @@ def mmh3(key: str, seed: int = 0):
h1 ^= _u(h1) >> 16 h1 ^= _u(h1) >> 16
return _u(h1) >> 0 return _u(h1) >> 0
async def pg_set_json(con):
"""Set JSON and JSONB codecs for an
asyncpg connection."""
await con.set_type_codec(
'json',
encoder=json.dumps,
decoder=json.loads,
schema='pg_catalog'
)
await con.set_type_codec(
'jsonb',
encoder=json.dumps,
decoder=json.loads,
schema='pg_catalog'
)