mirror of https://gitlab.com/litecord/litecord.git
channel.messages: use sanitize_embed
- embed: add sanitizer module - embed.schemas: add to_json to EmbedURL - utils: add custom JSON encoder - run: use custom JSON encoder - gateway.websocket: use custom JSON encoder
This commit is contained in:
parent
5f6ddad54d
commit
5db633b797
|
|
@ -12,6 +12,8 @@ from litecord.snowflake import get_snowflake
|
|||
from litecord.schemas import validate, MESSAGE_CREATE
|
||||
from litecord.utils import pg_set_json
|
||||
|
||||
from litecord.embed import sanitize_embed
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
bp = Blueprint('channel_messages', __name__)
|
||||
|
|
@ -258,7 +260,7 @@ async def _create_message(channel_id):
|
|||
'tts': is_tts,
|
||||
'nonce': int(j.get('nonce', 0)),
|
||||
'everyone_mention': mentions_everyone or mentions_here,
|
||||
'embeds': [j['embed']] if 'embed' in j else [],
|
||||
'embeds': [sanitize_embed(j['embed'])] if 'embed' in j else [],
|
||||
})
|
||||
|
||||
payload = await app.storage.get_message(message_id, user_id)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
from .sanitizer import sanitize_embed
|
||||
|
||||
__all__ = ['sanitize_embed']
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
"""
|
||||
litecord.embed.sanitizer
|
||||
sanitize embeds by giving common values
|
||||
such as type: rich
|
||||
"""
|
||||
from typing import Dict, Any
|
||||
from logbook import Logger
|
||||
|
||||
from litecord.embed.schemas import EmbedURL
|
||||
|
||||
log = Logger(__name__)
|
||||
Embed = Dict[str, Any]
|
||||
|
||||
|
||||
def _sane(v):
|
||||
if isinstance(v, EmbedURL):
|
||||
return v.to_json
|
||||
|
||||
return v
|
||||
|
||||
|
||||
def sanitize_embed(embed: Embed) -> Embed:
|
||||
"""Sanitize an embed object."""
|
||||
return {**embed, **{
|
||||
'type': 'rich'
|
||||
}}
|
||||
|
||||
|
||||
def path_exists(embed: Embed, components: str):
|
||||
"""Tell if a given path exists in an embed.
|
||||
|
||||
The components string is formatted like this:
|
||||
key1.key2.key3.key4. <...> .keyN
|
||||
|
||||
with each key going deeper and deeper into the embed.
|
||||
"""
|
||||
if isinstance(components, str):
|
||||
components = components.split('.')
|
||||
else:
|
||||
components = list(components)
|
||||
|
||||
if not components:
|
||||
return True
|
||||
|
||||
current = components[0]
|
||||
|
||||
if current in embed:
|
||||
return path_exists(embed[current], components[1:])
|
||||
|
||||
return False
|
||||
|
||||
|
||||
async def fill_embed(embed: Embed) -> Embed:
|
||||
"""Fill an embed with more information."""
|
||||
embed = sanitize_embed(embed)
|
||||
|
||||
if path_exists(embed, 'footer.icon_url'):
|
||||
# TODO: make proxy_icon_url
|
||||
log.warning('embed with footer.icon_url, ignoring')
|
||||
|
||||
if path_exists(embed, 'image.url'):
|
||||
# TODO: make proxy_icon_url, width, height
|
||||
log.warning('embed with footer.image_url, ignoring')
|
||||
|
||||
if path_exists(embed, 'author.icon_url'):
|
||||
log.warning('embed with author.icon_url, ignoring')
|
||||
|
||||
return embed
|
||||
|
|
@ -20,6 +20,10 @@ class EmbedURL:
|
|||
"""Return the URL."""
|
||||
return urllib.parse.urlunparse(self.parsed)
|
||||
|
||||
@property
|
||||
def to_json(self):
|
||||
return self.url
|
||||
|
||||
|
||||
EMBED_FOOTER = {
|
||||
'text': {
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import earl
|
|||
from litecord.auth import raw_token_check
|
||||
from litecord.enums import RelationshipType
|
||||
from litecord.schemas import validate, GW_STATUS_UPDATE
|
||||
from litecord.utils import task_wrapper
|
||||
from litecord.utils import task_wrapper, LitecordJSONEncoder
|
||||
from litecord.permissions import get_permissions
|
||||
|
||||
from litecord.gateway.opcodes import OP
|
||||
|
|
@ -39,7 +39,8 @@ WebsocketObjects = collections.namedtuple(
|
|||
|
||||
|
||||
def encode_json(payload) -> str:
|
||||
return json.dumps(payload, separators=(',', ':'))
|
||||
return json.dumps(payload, separators=(',', ':'),
|
||||
cls=LitecordJSONEncoder)
|
||||
|
||||
|
||||
def decode_json(data: str):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import asyncio
|
||||
import json
|
||||
from logbook import Logger
|
||||
from typing import Any
|
||||
from quart.json import JSONEncoder
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
|
@ -113,19 +115,27 @@ def mmh3(key: str, seed: int = 0):
|
|||
return _u(h1) >> 0
|
||||
|
||||
|
||||
class LitecordJSONEncoder(JSONEncoder):
|
||||
def default(self, value: Any):
|
||||
try:
|
||||
return value.to_json
|
||||
except AttributeError:
|
||||
return super().default(value)
|
||||
|
||||
|
||||
async def pg_set_json(con):
|
||||
"""Set JSON and JSONB codecs for an
|
||||
asyncpg connection."""
|
||||
await con.set_type_codec(
|
||||
'json',
|
||||
encoder=json.dumps,
|
||||
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||
decoder=json.loads,
|
||||
schema='pg_catalog'
|
||||
)
|
||||
|
||||
await con.set_type_codec(
|
||||
'jsonb',
|
||||
encoder=json.dumps,
|
||||
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||
decoder=json.loads,
|
||||
schema='pg_catalog'
|
||||
)
|
||||
|
|
|
|||
5
run.py
5
run.py
|
|
@ -50,6 +50,8 @@ from litecord.presence import PresenceManager
|
|||
from litecord.images import IconManager
|
||||
from litecord.jobs import JobManager
|
||||
|
||||
from litecord.utils import LitecordJSONEncoder
|
||||
|
||||
# setup logbook
|
||||
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
||||
handler.push_application()
|
||||
|
|
@ -71,6 +73,9 @@ def make_app():
|
|||
# always keep websockets on INFO
|
||||
logging.getLogger('websockets').setLevel(logbook.INFO)
|
||||
|
||||
# use our custom json encoder for custom data types
|
||||
app.json_encoder = LitecordJSONEncoder
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue