mirror of https://gitlab.com/litecord/litecord.git
channel.messages: use sanitize_embed
- embed: add sanitizer module - embed.schemas: add to_json to EmbedURL - utils: add custom JSON encoder - run: use custom JSON encoder - gateway.websocket: use custom JSON encoder
This commit is contained in:
parent
5f6ddad54d
commit
5db633b797
|
|
@ -12,6 +12,8 @@ from litecord.snowflake import get_snowflake
|
||||||
from litecord.schemas import validate, MESSAGE_CREATE
|
from litecord.schemas import validate, MESSAGE_CREATE
|
||||||
from litecord.utils import pg_set_json
|
from litecord.utils import pg_set_json
|
||||||
|
|
||||||
|
from litecord.embed import sanitize_embed
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('channel_messages', __name__)
|
bp = Blueprint('channel_messages', __name__)
|
||||||
|
|
@ -258,7 +260,7 @@ async def _create_message(channel_id):
|
||||||
'tts': is_tts,
|
'tts': is_tts,
|
||||||
'nonce': int(j.get('nonce', 0)),
|
'nonce': int(j.get('nonce', 0)),
|
||||||
'everyone_mention': mentions_everyone or mentions_here,
|
'everyone_mention': mentions_everyone or mentions_here,
|
||||||
'embeds': [j['embed']] if 'embed' in j else [],
|
'embeds': [sanitize_embed(j['embed'])] if 'embed' in j else [],
|
||||||
})
|
})
|
||||||
|
|
||||||
payload = await app.storage.get_message(message_id, user_id)
|
payload = await app.storage.get_message(message_id, user_id)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .sanitizer import sanitize_embed
|
||||||
|
|
||||||
|
__all__ = ['sanitize_embed']
|
||||||
|
|
@ -0,0 +1,68 @@
|
||||||
|
"""
|
||||||
|
litecord.embed.sanitizer
|
||||||
|
sanitize embeds by giving common values
|
||||||
|
such as type: rich
|
||||||
|
"""
|
||||||
|
from typing import Dict, Any
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
|
from litecord.embed.schemas import EmbedURL
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
Embed = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def _sane(v):
|
||||||
|
if isinstance(v, EmbedURL):
|
||||||
|
return v.to_json
|
||||||
|
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_embed(embed: Embed) -> Embed:
|
||||||
|
"""Sanitize an embed object."""
|
||||||
|
return {**embed, **{
|
||||||
|
'type': 'rich'
|
||||||
|
}}
|
||||||
|
|
||||||
|
|
||||||
|
def path_exists(embed: Embed, components: str):
|
||||||
|
"""Tell if a given path exists in an embed.
|
||||||
|
|
||||||
|
The components string is formatted like this:
|
||||||
|
key1.key2.key3.key4. <...> .keyN
|
||||||
|
|
||||||
|
with each key going deeper and deeper into the embed.
|
||||||
|
"""
|
||||||
|
if isinstance(components, str):
|
||||||
|
components = components.split('.')
|
||||||
|
else:
|
||||||
|
components = list(components)
|
||||||
|
|
||||||
|
if not components:
|
||||||
|
return True
|
||||||
|
|
||||||
|
current = components[0]
|
||||||
|
|
||||||
|
if current in embed:
|
||||||
|
return path_exists(embed[current], components[1:])
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def fill_embed(embed: Embed) -> Embed:
|
||||||
|
"""Fill an embed with more information."""
|
||||||
|
embed = sanitize_embed(embed)
|
||||||
|
|
||||||
|
if path_exists(embed, 'footer.icon_url'):
|
||||||
|
# TODO: make proxy_icon_url
|
||||||
|
log.warning('embed with footer.icon_url, ignoring')
|
||||||
|
|
||||||
|
if path_exists(embed, 'image.url'):
|
||||||
|
# TODO: make proxy_icon_url, width, height
|
||||||
|
log.warning('embed with footer.image_url, ignoring')
|
||||||
|
|
||||||
|
if path_exists(embed, 'author.icon_url'):
|
||||||
|
log.warning('embed with author.icon_url, ignoring')
|
||||||
|
|
||||||
|
return embed
|
||||||
|
|
@ -20,6 +20,10 @@ class EmbedURL:
|
||||||
"""Return the URL."""
|
"""Return the URL."""
|
||||||
return urllib.parse.urlunparse(self.parsed)
|
return urllib.parse.urlunparse(self.parsed)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def to_json(self):
|
||||||
|
return self.url
|
||||||
|
|
||||||
|
|
||||||
EMBED_FOOTER = {
|
EMBED_FOOTER = {
|
||||||
'text': {
|
'text': {
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ import earl
|
||||||
from litecord.auth import raw_token_check
|
from litecord.auth import raw_token_check
|
||||||
from litecord.enums import RelationshipType
|
from litecord.enums import RelationshipType
|
||||||
from litecord.schemas import validate, GW_STATUS_UPDATE
|
from litecord.schemas import validate, GW_STATUS_UPDATE
|
||||||
from litecord.utils import task_wrapper
|
from litecord.utils import task_wrapper, LitecordJSONEncoder
|
||||||
from litecord.permissions import get_permissions
|
from litecord.permissions import get_permissions
|
||||||
|
|
||||||
from litecord.gateway.opcodes import OP
|
from litecord.gateway.opcodes import OP
|
||||||
|
|
@ -39,7 +39,8 @@ WebsocketObjects = collections.namedtuple(
|
||||||
|
|
||||||
|
|
||||||
def encode_json(payload) -> str:
|
def encode_json(payload) -> str:
|
||||||
return json.dumps(payload, separators=(',', ':'))
|
return json.dumps(payload, separators=(',', ':'),
|
||||||
|
cls=LitecordJSONEncoder)
|
||||||
|
|
||||||
|
|
||||||
def decode_json(data: str):
|
def decode_json(data: str):
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
from typing import Any
|
||||||
|
from quart.json import JSONEncoder
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
@ -113,19 +115,27 @@ def mmh3(key: str, seed: int = 0):
|
||||||
return _u(h1) >> 0
|
return _u(h1) >> 0
|
||||||
|
|
||||||
|
|
||||||
|
class LitecordJSONEncoder(JSONEncoder):
|
||||||
|
def default(self, value: Any):
|
||||||
|
try:
|
||||||
|
return value.to_json
|
||||||
|
except AttributeError:
|
||||||
|
return super().default(value)
|
||||||
|
|
||||||
|
|
||||||
async def pg_set_json(con):
|
async def pg_set_json(con):
|
||||||
"""Set JSON and JSONB codecs for an
|
"""Set JSON and JSONB codecs for an
|
||||||
asyncpg connection."""
|
asyncpg connection."""
|
||||||
await con.set_type_codec(
|
await con.set_type_codec(
|
||||||
'json',
|
'json',
|
||||||
encoder=json.dumps,
|
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||||
decoder=json.loads,
|
decoder=json.loads,
|
||||||
schema='pg_catalog'
|
schema='pg_catalog'
|
||||||
)
|
)
|
||||||
|
|
||||||
await con.set_type_codec(
|
await con.set_type_codec(
|
||||||
'jsonb',
|
'jsonb',
|
||||||
encoder=json.dumps,
|
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||||
decoder=json.loads,
|
decoder=json.loads,
|
||||||
schema='pg_catalog'
|
schema='pg_catalog'
|
||||||
)
|
)
|
||||||
|
|
|
||||||
5
run.py
5
run.py
|
|
@ -50,6 +50,8 @@ from litecord.presence import PresenceManager
|
||||||
from litecord.images import IconManager
|
from litecord.images import IconManager
|
||||||
from litecord.jobs import JobManager
|
from litecord.jobs import JobManager
|
||||||
|
|
||||||
|
from litecord.utils import LitecordJSONEncoder
|
||||||
|
|
||||||
# setup logbook
|
# setup logbook
|
||||||
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
||||||
handler.push_application()
|
handler.push_application()
|
||||||
|
|
@ -71,6 +73,9 @@ def make_app():
|
||||||
# always keep websockets on INFO
|
# always keep websockets on INFO
|
||||||
logging.getLogger('websockets').setLevel(logbook.INFO)
|
logging.getLogger('websockets').setLevel(logbook.INFO)
|
||||||
|
|
||||||
|
# use our custom json encoder for custom data types
|
||||||
|
app.json_encoder = LitecordJSONEncoder
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue