diff --git a/litecord/blueprints/channel/messages.py b/litecord/blueprints/channel/messages.py index cf5d552..b1ef399 100644 --- a/litecord/blueprints/channel/messages.py +++ b/litecord/blueprints/channel/messages.py @@ -12,6 +12,8 @@ from litecord.snowflake import get_snowflake from litecord.schemas import validate, MESSAGE_CREATE from litecord.utils import pg_set_json +from litecord.embed import sanitize_embed + log = Logger(__name__) bp = Blueprint('channel_messages', __name__) @@ -258,7 +260,7 @@ async def _create_message(channel_id): 'tts': is_tts, 'nonce': int(j.get('nonce', 0)), 'everyone_mention': mentions_everyone or mentions_here, - 'embeds': [j['embed']] if 'embed' in j else [], + 'embeds': [sanitize_embed(j['embed'])] if 'embed' in j else [], }) payload = await app.storage.get_message(message_id, user_id) diff --git a/litecord/embed/__init__.py b/litecord/embed/__init__.py new file mode 100644 index 0000000..b77baa8 --- /dev/null +++ b/litecord/embed/__init__.py @@ -0,0 +1,3 @@ +from .sanitizer import sanitize_embed + +__all__ = ['sanitize_embed'] diff --git a/litecord/embed/sanitizer.py b/litecord/embed/sanitizer.py new file mode 100644 index 0000000..eb9921f --- /dev/null +++ b/litecord/embed/sanitizer.py @@ -0,0 +1,68 @@ +""" +litecord.embed.sanitizer + sanitize embeds by giving common values + such as type: rich +""" +from typing import Dict, Any +from logbook import Logger + +from litecord.embed.schemas import EmbedURL + +log = Logger(__name__) +Embed = Dict[str, Any] + + +def _sane(v): + if isinstance(v, EmbedURL): + return v.to_json + + return v + + +def sanitize_embed(embed: Embed) -> Embed: + """Sanitize an embed object.""" + return {**embed, **{ + 'type': 'rich' + }} + + +def path_exists(embed: Embed, components: str): + """Tell if a given path exists in an embed. + + The components string is formatted like this: + key1.key2.key3.key4. <...> .keyN + + with each key going deeper and deeper into the embed. + """ + if isinstance(components, str): + components = components.split('.') + else: + components = list(components) + + if not components: + return True + + current = components[0] + + if current in embed: + return path_exists(embed[current], components[1:]) + + return False + + +async def fill_embed(embed: Embed) -> Embed: + """Fill an embed with more information.""" + embed = sanitize_embed(embed) + + if path_exists(embed, 'footer.icon_url'): + # TODO: make proxy_icon_url + log.warning('embed with footer.icon_url, ignoring') + + if path_exists(embed, 'image.url'): + # TODO: make proxy_icon_url, width, height + log.warning('embed with footer.image_url, ignoring') + + if path_exists(embed, 'author.icon_url'): + log.warning('embed with author.icon_url, ignoring') + + return embed diff --git a/litecord/embed/schemas.py b/litecord/embed/schemas.py index cace161..2caf82d 100644 --- a/litecord/embed/schemas.py +++ b/litecord/embed/schemas.py @@ -20,6 +20,10 @@ class EmbedURL: """Return the URL.""" return urllib.parse.urlunparse(self.parsed) + @property + def to_json(self): + return self.url + EMBED_FOOTER = { 'text': { diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 76bb68a..6c83512 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -13,7 +13,7 @@ import earl from litecord.auth import raw_token_check from litecord.enums import RelationshipType from litecord.schemas import validate, GW_STATUS_UPDATE -from litecord.utils import task_wrapper +from litecord.utils import task_wrapper, LitecordJSONEncoder from litecord.permissions import get_permissions from litecord.gateway.opcodes import OP @@ -39,7 +39,8 @@ WebsocketObjects = collections.namedtuple( def encode_json(payload) -> str: - return json.dumps(payload, separators=(',', ':')) + return json.dumps(payload, separators=(',', ':'), + cls=LitecordJSONEncoder) def decode_json(data: str): diff --git a/litecord/utils.py b/litecord/utils.py index 330e67e..1aa693e 100644 --- a/litecord/utils.py +++ b/litecord/utils.py @@ -1,6 +1,8 @@ import asyncio import json from logbook import Logger +from typing import Any +from quart.json import JSONEncoder log = Logger(__name__) @@ -113,19 +115,27 @@ def mmh3(key: str, seed: int = 0): return _u(h1) >> 0 +class LitecordJSONEncoder(JSONEncoder): + def default(self, value: Any): + try: + return value.to_json + except AttributeError: + return super().default(value) + + async def pg_set_json(con): """Set JSON and JSONB codecs for an asyncpg connection.""" await con.set_type_codec( 'json', - encoder=json.dumps, + encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder), decoder=json.loads, schema='pg_catalog' ) await con.set_type_codec( 'jsonb', - encoder=json.dumps, + encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder), decoder=json.loads, schema='pg_catalog' ) diff --git a/run.py b/run.py index 7479c79..a7185ce 100644 --- a/run.py +++ b/run.py @@ -50,6 +50,8 @@ from litecord.presence import PresenceManager from litecord.images import IconManager from litecord.jobs import JobManager +from litecord.utils import LitecordJSONEncoder + # setup logbook handler = StreamHandler(sys.stdout, level=logbook.INFO) handler.push_application() @@ -71,6 +73,9 @@ def make_app(): # always keep websockets on INFO logging.getLogger('websockets').setLevel(logbook.INFO) + # use our custom json encoder for custom data types + app.json_encoder = LitecordJSONEncoder + return app