From 8b9719540455145f7bfd42be379fd9faf66a0f51 Mon Sep 17 00:00:00 2001 From: Luna Date: Tue, 4 Dec 2018 18:10:58 -0300 Subject: [PATCH 1/6] litecord: add embed namespace - embed: add embed.schemas - channel.messages: split some functions for readability --- litecord/blueprints/channel/messages.py | 157 ++++++++++++++---------- litecord/embed/schemas.py | 108 ++++++++++++++++ 2 files changed, 198 insertions(+), 67 deletions(-) create mode 100644 litecord/embed/schemas.py diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 5bdd925..4bd0206 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -139,8 +139,88 @@ async def _dm_pre_dispatch(channel_id, peer_id): await try_dm_state(peer_id, channel_id) +async def create_message(channel_id: int, actual_guild_id: int, + author_id: int, data: dict) -> int: + message_id = get_snowflake() + + await app.db.execute( + """ + INSERT INTO messages (id, channel_id, guild_id, author_id, + content, tts, mention_everyone, nonce, message_type) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + """, + message_id, + channel_id, + actual_guild_id, + author_id, + data['content'], + + data['tts'], + data['everyone_mention'], + + data['nonce'], + MessageType.DEFAULT.value + ) + + return message_id + +async def _guild_text_mentions(payload: dict, guild_id: int, + mentions_everyone: bool, mentions_here: bool): + channel_id = int(payload['channel_id']) + + # calculate the user ids we'll bump the mention count for + uids = set() + + # first is extracting user mentions + for mention in payload['mentions']: + uids.add(int(mention['id'])) + + # then role mentions + for role_mention in payload['mention_roles']: + role_id = int(role_mention) + member_ids = await app.storage.get_role_members(role_id) + + for member_id in member_ids: + uids.add(member_id) + + # at-here only updates the state + # for the users that have a state + # in the channel. + if mentions_here: + uids = [] + await app.db.execute(""" + UPDATE user_read_state + SET mention_count = mention_count + 1 + WHERE channel_id = $1 + """, channel_id) + + # at-here updates the read state + # for all users, including the ones + # that might not have read permissions + # to the channel. + if mentions_everyone: + uids = [] + + member_ids = await app.storage.get_member_ids(guild_id) + + await app.db.executemany(""" + UPDATE user_read_state + SET mention_count = mention_count + 1 + WHERE channel_id = $1 AND user_id = $2 + """, [(channel_id, uid) for uid in member_ids]) + + for user_id in uids: + await app.db.execute(""" + UPDATE user_read_state + SET mention_count = mention_count + 1 + WHERE user_id = $1 + AND channel_id = $2 + """, user_id, channel_id) + + @bp.route('//messages', methods=['POST']) -async def create_message(channel_id): +async def _create_message(channel_id): + """Create a message.""" user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) @@ -167,24 +247,12 @@ async def create_message(channel_id): user_id, channel_id, 'send_tts_messages', False )) - await app.db.execute( - """ - INSERT INTO messages (id, channel_id, guild_id, author_id, - content, tts, mention_everyone, nonce, message_type) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - """, - message_id, - channel_id, - actual_guild_id, - user_id, - j['content'], - - is_tts, - mentions_everyone or mentions_here, - - int(j.get('nonce', 0)), - MessageType.DEFAULT.value - ) + await create_message(channel_id, actual_guild_id, user_id, { + 'content': j['content'], + 'tts': is_tts, + 'nonce': int(j.get('nonce', 0)), + 'everyone_mention': mentions_everyone or mentions_here, + }) payload = await app.storage.get_message(message_id, user_id) @@ -196,6 +264,7 @@ async def create_message(channel_id): await app.dispatcher.dispatch('channel', channel_id, 'MESSAGE_CREATE', payload) + # update read state for the author await app.db.execute(""" UPDATE user_read_state SET last_message_id = $1 @@ -203,54 +272,8 @@ async def create_message(channel_id): """, message_id, channel_id, user_id) if ctype == ChannelType.GUILD_TEXT: - # calculate the user ids we'll bump the mention count for - uids = set() - - # first is extracting user mentions - for mention in payload['mentions']: - uids.add(int(mention['id'])) - - # then role mentions - for role_mention in payload['mention_roles']: - role_id = int(role_mention) - member_ids = await app.storage.get_role_members(role_id) - - for member_id in member_ids: - uids.add(member_id) - - # at-here only updates the state - # for the users that have a state - # in the channel. - if mentions_here: - uids = [] - await app.db.execute(""" - UPDATE user_read_state - SET mention_count = mention_count + 1 - WHERE channel_id = $1 - """, channel_id) - - # at-here updates the read state - # for all users, including the ones - # that might not have read permissions - # to the channel. - if mentions_everyone: - uids = [] - - member_ids = await app.storage.get_member_ids(guild_id) - - await app.db.executemany(""" - UPDATE user_read_state - SET mention_count = mention_count + 1 - WHERE channel_id = $1 AND user_id = $2 - """, [(channel_id, uid) for uid in member_ids]) - - for user_id in uids: - await app.db.execute(""" - UPDATE user_read_state - SET mention_count = mention_count + 1 - WHERE user_id = $1 - AND channel_id = $2 - """, user_id, channel_id) + await _guild_text_mentions(payload, guild_id, + mentions_everyone, mentions_here) return jsonify(payload) diff --git a/litecord/embed/schemas.py b/litecord/embed/schemas.py new file mode 100644 index 0000000..1d80867 --- /dev/null +++ b/litecord/embed/schemas.py @@ -0,0 +1,108 @@ +""" +litecord.embed.schemas - embed input validators. +""" +import urllib.parse + +from litecord.types import Color + + +class EmbedURL: + def __init__(self, url: str): + parsed = urllib.parse.urlparse(url) + + if parsed.scheme not in ('http', 'https', 'attachment'): + raise ValueError('Invalid URL scheme') + + self.raw_url = url + self.parsed = parsed + + @property + def url(self): + """Return the URL.""" + return urllib.parse.urlunparse(self.parsed) + + +EMBED_FOOTER = { + 'text': { + 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': True}, + + 'icon_url': { + 'coerce': EmbedURL, 'required': False, + }, + + # NOTE: proxy_icon_url set by us +} + +EMBED_IMAGE = { + 'url': {'coerce': EmbedURL, 'required': True}, + + # NOTE: proxy_url, width, height set by us +} + +EMBED_THUMBNAIL = EMBED_IMAGE + +EMBED_AUTHOR = { + 'name': { + 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False + }, + 'url': { + 'type': EmbedURL, 'required': False, + }, + 'icon_url': { + 'coerce': EmbedURL, 'required': False, + } +} + +EMBED_OBJECT = { + 'title': { + 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False}, + 'type': { + 'type': 'string', 'allowed': ['rich'], 'required': False, + 'default': 'rich' + }, + 'description': { + 'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': False, + }, + 'url': { + 'coerce': EmbedURL, 'required': False, + }, + 'timestamp': { + # TODO: an ISO 8601 type + # TODO: maybe replace the default in here with now().isoformat? + 'type': 'string', 'required': False + }, + + 'color': { + 'coerce': Color, 'required': False + }, + + 'footer': { + 'type': 'dict', + 'schema': EMBED_FOOTER, + 'required': False, + }, + 'image': { + 'type': 'dict', + 'schema': EMBED_IMAGE, + 'required': False, + }, + 'thumbnail': { + 'type': 'dict', + 'schema': EMBED_THUMBNAIL, + 'required': False, + }, + + # NOTE: 'video' set by us + # NOTE: 'provider' set by us + + 'author': { + 'type': 'dict', + 'schema': EMBED_AUTHOR, + 'required': False, + }, + 'fields': { + 'type': 'list', + 'schema': EMBED_AUTHOR, + 'required': False, + }, +} From 94fe51ac69c38da665b414d93f1a478a71239f17 Mon Sep 17 00:00:00 2001 From: Luna Date: Tue, 4 Dec 2018 18:37:42 -0300 Subject: [PATCH 2/6] channel.messages: add embed insertion - storage: move pg_set_json to litecord.utils to fix circular imports - storage: add embed fetch to get_message - embed.schemas: fix author.url's url - schemas: add EMBED_OBJECT to MESSAGE_CREATE --- litecord/blueprints/channel/messages.py | 53 ++++++++++++++----------- litecord/embed/schemas.py | 8 +--- litecord/schemas.py | 7 ++++ litecord/storage.py | 30 +++----------- litecord/utils.py | 19 +++++++++ 5 files changed, 63 insertions(+), 54 deletions(-) diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index 4bd0206..cf5d552 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -10,6 +10,7 @@ from litecord.errors import MessageNotFound, Forbidden, BadRequest from litecord.enums import MessageType, ChannelType, GUILD_CHANS from litecord.snowflake import get_snowflake from litecord.schemas import validate, MESSAGE_CREATE +from litecord.utils import pg_set_json log = Logger(__name__) @@ -143,24 +144,29 @@ async def create_message(channel_id: int, actual_guild_id: int, author_id: int, data: dict) -> int: message_id = get_snowflake() - await app.db.execute( - """ - INSERT INTO messages (id, channel_id, guild_id, author_id, - content, tts, mention_everyone, nonce, message_type) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) - """, - message_id, - channel_id, - actual_guild_id, - author_id, - data['content'], + async with app.db.acquire() as conn: + await pg_set_json(conn) - data['tts'], - data['everyone_mention'], + await conn.execute( + """ + INSERT INTO messages (id, channel_id, guild_id, author_id, + content, tts, mention_everyone, nonce, message_type, + embeds) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + """, + message_id, + channel_id, + actual_guild_id, + author_id, + data['content'], - data['nonce'], - MessageType.DEFAULT.value - ) + data['tts'], + data['everyone_mention'], + + data['nonce'], + MessageType.DEFAULT.value, + data.get('embeds', []) + ) return message_id @@ -231,7 +237,6 @@ async def _create_message(channel_id): actual_guild_id = guild_id j = validate(await request.get_json(), MESSAGE_CREATE) - message_id = get_snowflake() # TODO: check connection to the gateway @@ -247,12 +252,14 @@ async def _create_message(channel_id): user_id, channel_id, 'send_tts_messages', False )) - await create_message(channel_id, actual_guild_id, user_id, { - 'content': j['content'], - 'tts': is_tts, - 'nonce': int(j.get('nonce', 0)), - 'everyone_mention': mentions_everyone or mentions_here, - }) + message_id = await create_message( + channel_id, actual_guild_id, user_id, { + 'content': j['content'], + 'tts': is_tts, + 'nonce': int(j.get('nonce', 0)), + 'everyone_mention': mentions_everyone or mentions_here, + 'embeds': [j['embed']] if 'embed' in j else [], + }) payload = await app.storage.get_message(message_id, user_id) diff --git a/litecord/embed/schemas.py b/litecord/embed/schemas.py index 1d80867..cace161 100644 --- a/litecord/embed/schemas.py +++ b/litecord/embed/schemas.py @@ -2,7 +2,6 @@ litecord.embed.schemas - embed input validators. """ import urllib.parse - from litecord.types import Color @@ -46,7 +45,7 @@ EMBED_AUTHOR = { 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False }, 'url': { - 'type': EmbedURL, 'required': False, + 'coerce': EmbedURL, 'required': False, }, 'icon_url': { 'coerce': EmbedURL, 'required': False, @@ -56,10 +55,7 @@ EMBED_AUTHOR = { EMBED_OBJECT = { 'title': { 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False}, - 'type': { - 'type': 'string', 'allowed': ['rich'], 'required': False, - 'default': 'rich' - }, + # NOTE: type set by us 'description': { 'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': False, }, diff --git a/litecord/schemas.py b/litecord/schemas.py index 42915c8..1240404 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -12,6 +12,7 @@ from .enums import ( MessageNotifications, ChannelType, VerificationLevel ) +from litecord.embed.schemas import EMBED_OBJECT log = Logger(__name__) @@ -372,6 +373,12 @@ MESSAGE_CREATE = { 'nonce': {'type': 'snowflake', 'required': False}, 'tts': {'type': 'boolean', 'required': False}, + 'embed': { + 'type': 'dict', + 'schema': EMBED_OBJECT, + 'required': False + } + # TODO: file, embed, payload_json } diff --git a/litecord/storage.py b/litecord/storage.py index 29cf405..f4cc1e5 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -12,6 +12,7 @@ from litecord.blueprints.channel.reactions import ( from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE from litecord.types import timestamp_ +from litecord.utils import pg_set_json log = Logger(__name__) @@ -33,24 +34,6 @@ def str_(val): return maybe(str, val) -async def _set_json(con): - """Set JSON and JSONB codecs for an - asyncpg connection.""" - await con.set_type_codec( - 'json', - encoder=json.dumps, - decoder=json.loads, - schema='pg_catalog' - ) - - await con.set_type_codec( - 'jsonb', - encoder=json.dumps, - decoder=json.loads, - schema='pg_catalog' - ) - - def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int): """Filter recipients in a list of recipients, removing the one that is reundant (ourselves).""" @@ -73,13 +56,13 @@ class Storage: # set_type_codec, so we must set it manually # by acquiring the connection async with self.db.acquire() as con: - await _set_json(con) + await pg_set_json(con) return await con.fetchrow(query, *args) async def fetch_with_json(self, query: str, *args): """Fetch many rows with JSON/JSONB support.""" async with self.db.acquire() as con: - await _set_json(con) + await pg_set_json(con) return await con.fetch(query, *args) async def get_user(self, user_id, secure=False) -> Dict[str, Any]: @@ -631,10 +614,10 @@ class Storage: async def get_message(self, message_id: int, user_id=None) -> Dict: """Get a single message's payload.""" - row = await self.db.fetchrow(""" + row = await self.fetchrow_with_json(""" SELECT id::text, channel_id::text, author_id, content, created_at AS timestamp, edited_at AS edited_timestamp, - tts, mention_everyone, nonce, message_type + tts, mention_everyone, nonce, message_type, embeds FROM messages WHERE id = $1 """, message_id) @@ -698,9 +681,6 @@ class Storage: # TODO: res['attachments'] res['attachments'] = [] - # TODO: res['embeds'] - res['embeds'] = [] - # TODO: res['member'] for partial member data # of the author diff --git a/litecord/utils.py b/litecord/utils.py index 26c6ba2..330e67e 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -1,4 +1,5 @@ import asyncio +import json from logbook import Logger log = Logger(__name__) @@ -110,3 +111,21 @@ def mmh3(key: str, seed: int = 0): h1 ^= _u(h1) >> 16 return _u(h1) >> 0 + + +async def pg_set_json(con): + """Set JSON and JSONB codecs for an + asyncpg connection.""" + await con.set_type_codec( + 'json', + encoder=json.dumps, + decoder=json.loads, + schema='pg_catalog' + ) + + await con.set_type_codec( + 'jsonb', + encoder=json.dumps, + decoder=json.loads, + schema='pg_catalog' + ) From 5f6ddad54d00b19c736b153594989d7496d03d40 Mon Sep 17 00:00:00 2001 From: Luna Date: Tue, 4 Dec 2018 18:45:01 -0300 Subject: [PATCH 3/6] storage: remove unused json import --- litecord/storage.py | 1 - 1 file changed, 1 deletion(-) diff --git a/litecord/storage.py b/litecord/storage.py index f4cc1e5..c7d76f2 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -1,4 +1,3 @@ -import json from typing import List, Dict, Any from logbook import Logger From 5db633b79754001ea78fece4157137b9bb95beaa Mon Sep 17 00:00:00 2001 From: Luna Date: Tue, 4 Dec 2018 21:45:14 -0300 Subject: [PATCH 4/6] channel.messages: use sanitize_embed - embed: add sanitizer module - embed.schemas: add to_json to EmbedURL - utils: add custom JSON encoder - run: use custom JSON encoder - gateway.websocket: use custom JSON encoder --- litecord/blueprints/channel/messages.py | 4 +- litecord/embed/__init__.py | 3 ++ litecord/embed/sanitizer.py | 68 +++++++++++++++++++++++++ litecord/embed/schemas.py | 4 ++ litecord/gateway/websocket.py | 5 +- litecord/utils.py | 14 ++++- run.py | 5 ++ 7 files changed, 98 insertions(+), 5 deletions(-) create mode 100644 litecord/embed/__init__.py create mode 100644 litecord/embed/sanitizer.py diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index cf5d552..b1ef399 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -12,6 +12,8 @@ from litecord.snowflake import get_snowflake from litecord.schemas import validate, MESSAGE_CREATE from litecord.utils import pg_set_json +from litecord.embed import sanitize_embed + log = Logger(__name__) bp = Blueprint('channel_messages', __name__) @@ -258,7 +260,7 @@ async def _create_message(channel_id): 'tts': is_tts, 'nonce': int(j.get('nonce', 0)), 'everyone_mention': mentions_everyone or mentions_here, - 'embeds': [j['embed']] if 'embed' in j else [], + 'embeds': [sanitize_embed(j['embed'])] if 'embed' in j else [], }) payload = await app.storage.get_message(message_id, user_id) diff --git a/litecord/embed/__init__.py b/litecord/embed/__init__.py new file mode 100644 index 0000000..b77baa8 --- /dev/null +++ b/litecord/embed/__init__.py @@ -0,0 +1,3 @@ +from .sanitizer import sanitize_embed + +__all__ = ['sanitize_embed'] diff --git a/litecord/embed/sanitizer.py b/litecord/embed/sanitizer.py new file mode 100644 index 0000000..eb9921f --- /dev/null +++ b/litecord/embed/sanitizer.py @@ -0,0 +1,68 @@ +""" +litecord.embed.sanitizer + sanitize embeds by giving common values + such as type: rich +""" +from typing import Dict, Any +from logbook import Logger + +from litecord.embed.schemas import EmbedURL + +log = Logger(__name__) +Embed = Dict[str, Any] + + +def _sane(v): + if isinstance(v, EmbedURL): + return v.to_json + + return v + + +def sanitize_embed(embed: Embed) -> Embed: + """Sanitize an embed object.""" + return {**embed, **{ + 'type': 'rich' + }} + + +def path_exists(embed: Embed, components: str): + """Tell if a given path exists in an embed. + + The components string is formatted like this: + key1.key2.key3.key4. <...> .keyN + + with each key going deeper and deeper into the embed. + """ + if isinstance(components, str): + components = components.split('.') + else: + components = list(components) + + if not components: + return True + + current = components[0] + + if current in embed: + return path_exists(embed[current], components[1:]) + + return False + + +async def fill_embed(embed: Embed) -> Embed: + """Fill an embed with more information.""" + embed = sanitize_embed(embed) + + if path_exists(embed, 'footer.icon_url'): + # TODO: make proxy_icon_url + log.warning('embed with footer.icon_url, ignoring') + + if path_exists(embed, 'image.url'): + # TODO: make proxy_icon_url, width, height + log.warning('embed with footer.image_url, ignoring') + + if path_exists(embed, 'author.icon_url'): + log.warning('embed with author.icon_url, ignoring') + + return embed diff --git a/litecord/embed/schemas.py b/litecord/embed/schemas.py index cace161..2caf82d 100644 --- a/litecord/embed/schemas.py +++ b/litecord/embed/schemas.py @@ -20,6 +20,10 @@ class EmbedURL: """Return the URL.""" return urllib.parse.urlunparse(self.parsed) + @property + def to_json(self): + return self.url + EMBED_FOOTER = { 'text': { diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 76bb68a..6c83512 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -13,7 +13,7 @@ import earl from litecord.auth import raw_token_check from litecord.enums import RelationshipType from litecord.schemas import validate, GW_STATUS_UPDATE -from litecord.utils import task_wrapper +from litecord.utils import task_wrapper, LitecordJSONEncoder from litecord.permissions import get_permissions from litecord.gateway.opcodes import OP @@ -39,7 +39,8 @@ WebsocketObjects = collections.namedtuple( def encode_json(payload) -> str: - return json.dumps(payload, separators=(',', ':')) + return json.dumps(payload, separators=(',', ':'), + cls=LitecordJSONEncoder) def decode_json(data: str): diff --git a/litecord/utils.py b/litecord/utils.py index 330e67e..1aa693e 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -1,6 +1,8 @@ import asyncio import json from logbook import Logger +from typing import Any +from quart.json import JSONEncoder log = Logger(__name__) @@ -113,19 +115,27 @@ def mmh3(key: str, seed: int = 0): return _u(h1) >> 0 +class LitecordJSONEncoder(JSONEncoder): + def default(self, value: Any): + try: + return value.to_json + except AttributeError: + return super().default(value) + + async def pg_set_json(con): """Set JSON and JSONB codecs for an asyncpg connection.""" await con.set_type_codec( 'json', - encoder=json.dumps, + encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder), decoder=json.loads, schema='pg_catalog' ) await con.set_type_codec( 'jsonb', - encoder=json.dumps, + encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder), decoder=json.loads, schema='pg_catalog' ) diff --git a/run.py b/run.py index 7479c79..a7185ce 100644 --- a/run.py +++ b/run.py @@ -50,6 +50,8 @@ from litecord.presence import PresenceManager from litecord.images import IconManager from litecord.jobs import JobManager +from litecord.utils import LitecordJSONEncoder + # setup logbook handler = StreamHandler(sys.stdout, level=logbook.INFO) handler.push_application() @@ -71,6 +73,9 @@ def make_app(): # always keep websockets on INFO logging.getLogger('websockets').setLevel(logbook.INFO) + # use our custom json encoder for custom data types + app.json_encoder = LitecordJSONEncoder + return app From 5de64a93ee113c23b3e9291a1d5d960c2bb0f9b8 Mon Sep 17 00:00:00 2001 From: Luna Date: Tue, 4 Dec 2018 21:50:11 -0300 Subject: [PATCH 5/6] embed.sanitizer: remove unused _sane function --- litecord/embed/sanitizer.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/litecord/embed/sanitizer.py b/litecord/embed/sanitizer.py index eb9921f..be3e1be 100644 --- a/litecord/embed/sanitizer.py +++ b/litecord/embed/sanitizer.py @@ -12,41 +12,46 @@ log = Logger(__name__) Embed = Dict[str, Any] -def _sane(v): - if isinstance(v, EmbedURL): - return v.to_json - - return v - - def sanitize_embed(embed: Embed) -> Embed: - """Sanitize an embed object.""" + """Sanitize an embed object. + + This is non-complex sanitization as it doesn't + need the app object. + """ return {**embed, **{ 'type': 'rich' }} def path_exists(embed: Embed, components: str): - """Tell if a given path exists in an embed. + """Tell if a given path exists in an embed (or any dictionary). The components string is formatted like this: key1.key2.key3.key4. <...> .keyN with each key going deeper and deeper into the embed. """ + + # get the list of components given if isinstance(components, str): components = components.split('.') else: components = list(components) + # if there are no components, we reached the end of recursion + # and can return true if not components: return True + # extract current component current = components[0] + # if it exists, then we go down a level inside the dict + # (via recursion) if current in embed: return path_exists(embed[current], components[1:]) + # if it doesn't exist, return False return False @@ -63,6 +68,8 @@ async def fill_embed(embed: Embed) -> Embed: log.warning('embed with footer.image_url, ignoring') if path_exists(embed, 'author.icon_url'): + # TODO: should we check icon_url and convert it into + # a proxied icon url? log.warning('embed with author.icon_url, ignoring') return embed From 9b5902db95eb23125a706123e0c0be29bef98b47 Mon Sep 17 00:00:00 2001 From: Luna Date: Tue, 4 Dec 2018 22:54:39 -0300 Subject: [PATCH 6/6] tests: add test_embeds - embeds.schemas: add EMBED_FIELD and EMBED_OBJECT.fields to use it --- litecord/embed/sanitizer.py | 2 - litecord/embed/schemas.py | 15 ++++++- tests/test_embeds.py | 83 +++++++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 tests/test_embeds.py diff --git a/litecord/embed/sanitizer.py b/litecord/embed/sanitizer.py index be3e1be..3051cd3 100644 --- a/litecord/embed/sanitizer.py +++ b/litecord/embed/sanitizer.py @@ -6,8 +6,6 @@ litecord.embed.sanitizer from typing import Dict, Any from logbook import Logger -from litecord.embed.schemas import EmbedURL - log = Logger(__name__) Embed = Dict[str, Any] diff --git a/litecord/embed/schemas.py b/litecord/embed/schemas.py index 2caf82d..39b5881 100644 --- a/litecord/embed/schemas.py +++ b/litecord/embed/schemas.py @@ -56,6 +56,18 @@ EMBED_AUTHOR = { } } +EMBED_FIELD = { + 'name': { + 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': True + }, + 'value': { + 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': True + }, + 'inline': { + 'type': 'boolean', 'required': False, 'default': True, + }, +} + EMBED_OBJECT = { 'title': { 'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False}, @@ -100,9 +112,10 @@ EMBED_OBJECT = { 'schema': EMBED_AUTHOR, 'required': False, }, + 'fields': { 'type': 'list', - 'schema': EMBED_AUTHOR, + 'schema': {'type': 'dict', 'schema': EMBED_FIELD}, 'required': False, }, } diff --git a/tests/test_embeds.py b/tests/test_embeds.py new file mode 100644 index 0000000..0374b7c --- /dev/null +++ b/tests/test_embeds.py @@ -0,0 +1,83 @@ +from litecord.schemas import validate +from litecord.embed.schemas import EMBED_OBJECT + +def validate_embed(embed): + return validate(embed, EMBED_OBJECT) + +def valid(embed: dict): + try: + validate_embed(embed) + return True + except: + return False + +def invalid(embed): + try: + validate_embed(embed) + return False + except: + return True + + +def test_empty_embed(): + valid({}) + + +def test_basic_embed(): + assert valid({ + 'title': 'test', + 'description': 'acab', + 'url': 'https://www.w3.org', + 'color': 123 + }) + + +def test_footer_embed(): + assert invalid({ + 'footer': {} + }) + + assert valid({ + 'title': 'test', + 'footer': { + 'text': 'abcdef' + } + }) + +def test_image(): + assert invalid({ + 'image': {} + }) + + assert valid({ + 'image': { + 'url': 'https://www.w3.org' + } + }) + +def test_author(): + assert invalid({ + 'author': { + 'name': '' + } + }) + + assert valid({ + 'author': { + 'name': 'abcdef' + } + }) + +def test_fields(): + assert valid({ + 'fields': [ + {'name': 'a', 'value': 'b'}, + {'name': 'c', 'value': 'd', 'inline': False}, + ] + }) + + valid({ + 'fields': [ + {'name': 'a'}, + ] + })