mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'message-embeds' into 'master'
Message embeds Closes #17 See merge request litecord/litecord!10
This commit is contained in:
commit
d2bd6dd342
|
|
@ -10,6 +10,9 @@ from litecord.errors import MessageNotFound, Forbidden, BadRequest
|
||||||
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
||||||
from litecord.snowflake import get_snowflake
|
from litecord.snowflake import get_snowflake
|
||||||
from litecord.schemas import validate, MESSAGE_CREATE
|
from litecord.schemas import validate, MESSAGE_CREATE
|
||||||
|
from litecord.utils import pg_set_json
|
||||||
|
|
||||||
|
from litecord.embed import sanitize_embed
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
@ -139,8 +142,93 @@ async def _dm_pre_dispatch(channel_id, peer_id):
|
||||||
await try_dm_state(peer_id, channel_id)
|
await try_dm_state(peer_id, channel_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def create_message(channel_id: int, actual_guild_id: int,
|
||||||
|
author_id: int, data: dict) -> int:
|
||||||
|
message_id = get_snowflake()
|
||||||
|
|
||||||
|
async with app.db.acquire() as conn:
|
||||||
|
await pg_set_json(conn)
|
||||||
|
|
||||||
|
await conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO messages (id, channel_id, guild_id, author_id,
|
||||||
|
content, tts, mention_everyone, nonce, message_type,
|
||||||
|
embeds)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||||
|
""",
|
||||||
|
message_id,
|
||||||
|
channel_id,
|
||||||
|
actual_guild_id,
|
||||||
|
author_id,
|
||||||
|
data['content'],
|
||||||
|
|
||||||
|
data['tts'],
|
||||||
|
data['everyone_mention'],
|
||||||
|
|
||||||
|
data['nonce'],
|
||||||
|
MessageType.DEFAULT.value,
|
||||||
|
data.get('embeds', [])
|
||||||
|
)
|
||||||
|
|
||||||
|
return message_id
|
||||||
|
|
||||||
|
async def _guild_text_mentions(payload: dict, guild_id: int,
|
||||||
|
mentions_everyone: bool, mentions_here: bool):
|
||||||
|
channel_id = int(payload['channel_id'])
|
||||||
|
|
||||||
|
# calculate the user ids we'll bump the mention count for
|
||||||
|
uids = set()
|
||||||
|
|
||||||
|
# first is extracting user mentions
|
||||||
|
for mention in payload['mentions']:
|
||||||
|
uids.add(int(mention['id']))
|
||||||
|
|
||||||
|
# then role mentions
|
||||||
|
for role_mention in payload['mention_roles']:
|
||||||
|
role_id = int(role_mention)
|
||||||
|
member_ids = await app.storage.get_role_members(role_id)
|
||||||
|
|
||||||
|
for member_id in member_ids:
|
||||||
|
uids.add(member_id)
|
||||||
|
|
||||||
|
# at-here only updates the state
|
||||||
|
# for the users that have a state
|
||||||
|
# in the channel.
|
||||||
|
if mentions_here:
|
||||||
|
uids = []
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE user_read_state
|
||||||
|
SET mention_count = mention_count + 1
|
||||||
|
WHERE channel_id = $1
|
||||||
|
""", channel_id)
|
||||||
|
|
||||||
|
# at-here updates the read state
|
||||||
|
# for all users, including the ones
|
||||||
|
# that might not have read permissions
|
||||||
|
# to the channel.
|
||||||
|
if mentions_everyone:
|
||||||
|
uids = []
|
||||||
|
|
||||||
|
member_ids = await app.storage.get_member_ids(guild_id)
|
||||||
|
|
||||||
|
await app.db.executemany("""
|
||||||
|
UPDATE user_read_state
|
||||||
|
SET mention_count = mention_count + 1
|
||||||
|
WHERE channel_id = $1 AND user_id = $2
|
||||||
|
""", [(channel_id, uid) for uid in member_ids])
|
||||||
|
|
||||||
|
for user_id in uids:
|
||||||
|
await app.db.execute("""
|
||||||
|
UPDATE user_read_state
|
||||||
|
SET mention_count = mention_count + 1
|
||||||
|
WHERE user_id = $1
|
||||||
|
AND channel_id = $2
|
||||||
|
""", user_id, channel_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages', methods=['POST'])
|
@bp.route('/<int:channel_id>/messages', methods=['POST'])
|
||||||
async def create_message(channel_id):
|
async def _create_message(channel_id):
|
||||||
|
"""Create a message."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
|
@ -151,7 +239,6 @@ async def create_message(channel_id):
|
||||||
actual_guild_id = guild_id
|
actual_guild_id = guild_id
|
||||||
|
|
||||||
j = validate(await request.get_json(), MESSAGE_CREATE)
|
j = validate(await request.get_json(), MESSAGE_CREATE)
|
||||||
message_id = get_snowflake()
|
|
||||||
|
|
||||||
# TODO: check connection to the gateway
|
# TODO: check connection to the gateway
|
||||||
|
|
||||||
|
|
@ -167,24 +254,14 @@ async def create_message(channel_id):
|
||||||
user_id, channel_id, 'send_tts_messages', False
|
user_id, channel_id, 'send_tts_messages', False
|
||||||
))
|
))
|
||||||
|
|
||||||
await app.db.execute(
|
message_id = await create_message(
|
||||||
"""
|
channel_id, actual_guild_id, user_id, {
|
||||||
INSERT INTO messages (id, channel_id, guild_id, author_id,
|
'content': j['content'],
|
||||||
content, tts, mention_everyone, nonce, message_type)
|
'tts': is_tts,
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
'nonce': int(j.get('nonce', 0)),
|
||||||
""",
|
'everyone_mention': mentions_everyone or mentions_here,
|
||||||
message_id,
|
'embeds': [sanitize_embed(j['embed'])] if 'embed' in j else [],
|
||||||
channel_id,
|
})
|
||||||
actual_guild_id,
|
|
||||||
user_id,
|
|
||||||
j['content'],
|
|
||||||
|
|
||||||
is_tts,
|
|
||||||
mentions_everyone or mentions_here,
|
|
||||||
|
|
||||||
int(j.get('nonce', 0)),
|
|
||||||
MessageType.DEFAULT.value
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = await app.storage.get_message(message_id, user_id)
|
payload = await app.storage.get_message(message_id, user_id)
|
||||||
|
|
||||||
|
|
@ -196,6 +273,7 @@ async def create_message(channel_id):
|
||||||
await app.dispatcher.dispatch('channel', channel_id,
|
await app.dispatcher.dispatch('channel', channel_id,
|
||||||
'MESSAGE_CREATE', payload)
|
'MESSAGE_CREATE', payload)
|
||||||
|
|
||||||
|
# update read state for the author
|
||||||
await app.db.execute("""
|
await app.db.execute("""
|
||||||
UPDATE user_read_state
|
UPDATE user_read_state
|
||||||
SET last_message_id = $1
|
SET last_message_id = $1
|
||||||
|
|
@ -203,54 +281,8 @@ async def create_message(channel_id):
|
||||||
""", message_id, channel_id, user_id)
|
""", message_id, channel_id, user_id)
|
||||||
|
|
||||||
if ctype == ChannelType.GUILD_TEXT:
|
if ctype == ChannelType.GUILD_TEXT:
|
||||||
# calculate the user ids we'll bump the mention count for
|
await _guild_text_mentions(payload, guild_id,
|
||||||
uids = set()
|
mentions_everyone, mentions_here)
|
||||||
|
|
||||||
# first is extracting user mentions
|
|
||||||
for mention in payload['mentions']:
|
|
||||||
uids.add(int(mention['id']))
|
|
||||||
|
|
||||||
# then role mentions
|
|
||||||
for role_mention in payload['mention_roles']:
|
|
||||||
role_id = int(role_mention)
|
|
||||||
member_ids = await app.storage.get_role_members(role_id)
|
|
||||||
|
|
||||||
for member_id in member_ids:
|
|
||||||
uids.add(member_id)
|
|
||||||
|
|
||||||
# at-here only updates the state
|
|
||||||
# for the users that have a state
|
|
||||||
# in the channel.
|
|
||||||
if mentions_here:
|
|
||||||
uids = []
|
|
||||||
await app.db.execute("""
|
|
||||||
UPDATE user_read_state
|
|
||||||
SET mention_count = mention_count + 1
|
|
||||||
WHERE channel_id = $1
|
|
||||||
""", channel_id)
|
|
||||||
|
|
||||||
# at-here updates the read state
|
|
||||||
# for all users, including the ones
|
|
||||||
# that might not have read permissions
|
|
||||||
# to the channel.
|
|
||||||
if mentions_everyone:
|
|
||||||
uids = []
|
|
||||||
|
|
||||||
member_ids = await app.storage.get_member_ids(guild_id)
|
|
||||||
|
|
||||||
await app.db.executemany("""
|
|
||||||
UPDATE user_read_state
|
|
||||||
SET mention_count = mention_count + 1
|
|
||||||
WHERE channel_id = $1 AND user_id = $2
|
|
||||||
""", [(channel_id, uid) for uid in member_ids])
|
|
||||||
|
|
||||||
for user_id in uids:
|
|
||||||
await app.db.execute("""
|
|
||||||
UPDATE user_read_state
|
|
||||||
SET mention_count = mention_count + 1
|
|
||||||
WHERE user_id = $1
|
|
||||||
AND channel_id = $2
|
|
||||||
""", user_id, channel_id)
|
|
||||||
|
|
||||||
return jsonify(payload)
|
return jsonify(payload)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .sanitizer import sanitize_embed
|
||||||
|
|
||||||
|
__all__ = ['sanitize_embed']
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
"""
|
||||||
|
litecord.embed.sanitizer
|
||||||
|
sanitize embeds by giving common values
|
||||||
|
such as type: rich
|
||||||
|
"""
|
||||||
|
from typing import Dict, Any
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
Embed = Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_embed(embed: Embed) -> Embed:
|
||||||
|
"""Sanitize an embed object.
|
||||||
|
|
||||||
|
This is non-complex sanitization as it doesn't
|
||||||
|
need the app object.
|
||||||
|
"""
|
||||||
|
return {**embed, **{
|
||||||
|
'type': 'rich'
|
||||||
|
}}
|
||||||
|
|
||||||
|
|
||||||
|
def path_exists(embed: Embed, components: str):
|
||||||
|
"""Tell if a given path exists in an embed (or any dictionary).
|
||||||
|
|
||||||
|
The components string is formatted like this:
|
||||||
|
key1.key2.key3.key4. <...> .keyN
|
||||||
|
|
||||||
|
with each key going deeper and deeper into the embed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# get the list of components given
|
||||||
|
if isinstance(components, str):
|
||||||
|
components = components.split('.')
|
||||||
|
else:
|
||||||
|
components = list(components)
|
||||||
|
|
||||||
|
# if there are no components, we reached the end of recursion
|
||||||
|
# and can return true
|
||||||
|
if not components:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# extract current component
|
||||||
|
current = components[0]
|
||||||
|
|
||||||
|
# if it exists, then we go down a level inside the dict
|
||||||
|
# (via recursion)
|
||||||
|
if current in embed:
|
||||||
|
return path_exists(embed[current], components[1:])
|
||||||
|
|
||||||
|
# if it doesn't exist, return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def fill_embed(embed: Embed) -> Embed:
|
||||||
|
"""Fill an embed with more information."""
|
||||||
|
embed = sanitize_embed(embed)
|
||||||
|
|
||||||
|
if path_exists(embed, 'footer.icon_url'):
|
||||||
|
# TODO: make proxy_icon_url
|
||||||
|
log.warning('embed with footer.icon_url, ignoring')
|
||||||
|
|
||||||
|
if path_exists(embed, 'image.url'):
|
||||||
|
# TODO: make proxy_icon_url, width, height
|
||||||
|
log.warning('embed with footer.image_url, ignoring')
|
||||||
|
|
||||||
|
if path_exists(embed, 'author.icon_url'):
|
||||||
|
# TODO: should we check icon_url and convert it into
|
||||||
|
# a proxied icon url?
|
||||||
|
log.warning('embed with author.icon_url, ignoring')
|
||||||
|
|
||||||
|
return embed
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
"""
|
||||||
|
litecord.embed.schemas - embed input validators.
|
||||||
|
"""
|
||||||
|
import urllib.parse
|
||||||
|
from litecord.types import Color
|
||||||
|
|
||||||
|
|
||||||
|
class EmbedURL:
|
||||||
|
def __init__(self, url: str):
|
||||||
|
parsed = urllib.parse.urlparse(url)
|
||||||
|
|
||||||
|
if parsed.scheme not in ('http', 'https', 'attachment'):
|
||||||
|
raise ValueError('Invalid URL scheme')
|
||||||
|
|
||||||
|
self.raw_url = url
|
||||||
|
self.parsed = parsed
|
||||||
|
|
||||||
|
@property
|
||||||
|
def url(self):
|
||||||
|
"""Return the URL."""
|
||||||
|
return urllib.parse.urlunparse(self.parsed)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def to_json(self):
|
||||||
|
return self.url
|
||||||
|
|
||||||
|
|
||||||
|
EMBED_FOOTER = {
|
||||||
|
'text': {
|
||||||
|
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': True},
|
||||||
|
|
||||||
|
'icon_url': {
|
||||||
|
'coerce': EmbedURL, 'required': False,
|
||||||
|
},
|
||||||
|
|
||||||
|
# NOTE: proxy_icon_url set by us
|
||||||
|
}
|
||||||
|
|
||||||
|
EMBED_IMAGE = {
|
||||||
|
'url': {'coerce': EmbedURL, 'required': True},
|
||||||
|
|
||||||
|
# NOTE: proxy_url, width, height set by us
|
||||||
|
}
|
||||||
|
|
||||||
|
EMBED_THUMBNAIL = EMBED_IMAGE
|
||||||
|
|
||||||
|
EMBED_AUTHOR = {
|
||||||
|
'name': {
|
||||||
|
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False
|
||||||
|
},
|
||||||
|
'url': {
|
||||||
|
'coerce': EmbedURL, 'required': False,
|
||||||
|
},
|
||||||
|
'icon_url': {
|
||||||
|
'coerce': EmbedURL, 'required': False,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EMBED_FIELD = {
|
||||||
|
'name': {
|
||||||
|
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': True
|
||||||
|
},
|
||||||
|
'value': {
|
||||||
|
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': True
|
||||||
|
},
|
||||||
|
'inline': {
|
||||||
|
'type': 'boolean', 'required': False, 'default': True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
EMBED_OBJECT = {
|
||||||
|
'title': {
|
||||||
|
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False},
|
||||||
|
# NOTE: type set by us
|
||||||
|
'description': {
|
||||||
|
'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': False,
|
||||||
|
},
|
||||||
|
'url': {
|
||||||
|
'coerce': EmbedURL, 'required': False,
|
||||||
|
},
|
||||||
|
'timestamp': {
|
||||||
|
# TODO: an ISO 8601 type
|
||||||
|
# TODO: maybe replace the default in here with now().isoformat?
|
||||||
|
'type': 'string', 'required': False
|
||||||
|
},
|
||||||
|
|
||||||
|
'color': {
|
||||||
|
'coerce': Color, 'required': False
|
||||||
|
},
|
||||||
|
|
||||||
|
'footer': {
|
||||||
|
'type': 'dict',
|
||||||
|
'schema': EMBED_FOOTER,
|
||||||
|
'required': False,
|
||||||
|
},
|
||||||
|
'image': {
|
||||||
|
'type': 'dict',
|
||||||
|
'schema': EMBED_IMAGE,
|
||||||
|
'required': False,
|
||||||
|
},
|
||||||
|
'thumbnail': {
|
||||||
|
'type': 'dict',
|
||||||
|
'schema': EMBED_THUMBNAIL,
|
||||||
|
'required': False,
|
||||||
|
},
|
||||||
|
|
||||||
|
# NOTE: 'video' set by us
|
||||||
|
# NOTE: 'provider' set by us
|
||||||
|
|
||||||
|
'author': {
|
||||||
|
'type': 'dict',
|
||||||
|
'schema': EMBED_AUTHOR,
|
||||||
|
'required': False,
|
||||||
|
},
|
||||||
|
|
||||||
|
'fields': {
|
||||||
|
'type': 'list',
|
||||||
|
'schema': {'type': 'dict', 'schema': EMBED_FIELD},
|
||||||
|
'required': False,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
@ -13,7 +13,7 @@ import earl
|
||||||
from litecord.auth import raw_token_check
|
from litecord.auth import raw_token_check
|
||||||
from litecord.enums import RelationshipType
|
from litecord.enums import RelationshipType
|
||||||
from litecord.schemas import validate, GW_STATUS_UPDATE
|
from litecord.schemas import validate, GW_STATUS_UPDATE
|
||||||
from litecord.utils import task_wrapper
|
from litecord.utils import task_wrapper, LitecordJSONEncoder
|
||||||
from litecord.permissions import get_permissions
|
from litecord.permissions import get_permissions
|
||||||
|
|
||||||
from litecord.gateway.opcodes import OP
|
from litecord.gateway.opcodes import OP
|
||||||
|
|
@ -39,7 +39,8 @@ WebsocketObjects = collections.namedtuple(
|
||||||
|
|
||||||
|
|
||||||
def encode_json(payload) -> str:
|
def encode_json(payload) -> str:
|
||||||
return json.dumps(payload, separators=(',', ':'))
|
return json.dumps(payload, separators=(',', ':'),
|
||||||
|
cls=LitecordJSONEncoder)
|
||||||
|
|
||||||
|
|
||||||
def decode_json(data: str):
|
def decode_json(data: str):
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ from .enums import (
|
||||||
MessageNotifications, ChannelType, VerificationLevel
|
MessageNotifications, ChannelType, VerificationLevel
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from litecord.embed.schemas import EMBED_OBJECT
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
@ -372,6 +373,12 @@ MESSAGE_CREATE = {
|
||||||
'nonce': {'type': 'snowflake', 'required': False},
|
'nonce': {'type': 'snowflake', 'required': False},
|
||||||
'tts': {'type': 'boolean', 'required': False},
|
'tts': {'type': 'boolean', 'required': False},
|
||||||
|
|
||||||
|
'embed': {
|
||||||
|
'type': 'dict',
|
||||||
|
'schema': EMBED_OBJECT,
|
||||||
|
'required': False
|
||||||
|
}
|
||||||
|
|
||||||
# TODO: file, embed, payload_json
|
# TODO: file, embed, payload_json
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import json
|
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
@ -12,6 +11,7 @@ from litecord.blueprints.channel.reactions import (
|
||||||
from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
|
from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
|
||||||
|
|
||||||
from litecord.types import timestamp_
|
from litecord.types import timestamp_
|
||||||
|
from litecord.utils import pg_set_json
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
@ -33,24 +33,6 @@ def str_(val):
|
||||||
return maybe(str, val)
|
return maybe(str, val)
|
||||||
|
|
||||||
|
|
||||||
async def _set_json(con):
|
|
||||||
"""Set JSON and JSONB codecs for an
|
|
||||||
asyncpg connection."""
|
|
||||||
await con.set_type_codec(
|
|
||||||
'json',
|
|
||||||
encoder=json.dumps,
|
|
||||||
decoder=json.loads,
|
|
||||||
schema='pg_catalog'
|
|
||||||
)
|
|
||||||
|
|
||||||
await con.set_type_codec(
|
|
||||||
'jsonb',
|
|
||||||
encoder=json.dumps,
|
|
||||||
decoder=json.loads,
|
|
||||||
schema='pg_catalog'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int):
|
def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int):
|
||||||
"""Filter recipients in a list of recipients, removing
|
"""Filter recipients in a list of recipients, removing
|
||||||
the one that is reundant (ourselves)."""
|
the one that is reundant (ourselves)."""
|
||||||
|
|
@ -73,13 +55,13 @@ class Storage:
|
||||||
# set_type_codec, so we must set it manually
|
# set_type_codec, so we must set it manually
|
||||||
# by acquiring the connection
|
# by acquiring the connection
|
||||||
async with self.db.acquire() as con:
|
async with self.db.acquire() as con:
|
||||||
await _set_json(con)
|
await pg_set_json(con)
|
||||||
return await con.fetchrow(query, *args)
|
return await con.fetchrow(query, *args)
|
||||||
|
|
||||||
async def fetch_with_json(self, query: str, *args):
|
async def fetch_with_json(self, query: str, *args):
|
||||||
"""Fetch many rows with JSON/JSONB support."""
|
"""Fetch many rows with JSON/JSONB support."""
|
||||||
async with self.db.acquire() as con:
|
async with self.db.acquire() as con:
|
||||||
await _set_json(con)
|
await pg_set_json(con)
|
||||||
return await con.fetch(query, *args)
|
return await con.fetch(query, *args)
|
||||||
|
|
||||||
async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
|
async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
|
||||||
|
|
@ -631,10 +613,10 @@ class Storage:
|
||||||
|
|
||||||
async def get_message(self, message_id: int, user_id=None) -> Dict:
|
async def get_message(self, message_id: int, user_id=None) -> Dict:
|
||||||
"""Get a single message's payload."""
|
"""Get a single message's payload."""
|
||||||
row = await self.db.fetchrow("""
|
row = await self.fetchrow_with_json("""
|
||||||
SELECT id::text, channel_id::text, author_id, content,
|
SELECT id::text, channel_id::text, author_id, content,
|
||||||
created_at AS timestamp, edited_at AS edited_timestamp,
|
created_at AS timestamp, edited_at AS edited_timestamp,
|
||||||
tts, mention_everyone, nonce, message_type
|
tts, mention_everyone, nonce, message_type, embeds
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", message_id)
|
""", message_id)
|
||||||
|
|
@ -698,9 +680,6 @@ class Storage:
|
||||||
# TODO: res['attachments']
|
# TODO: res['attachments']
|
||||||
res['attachments'] = []
|
res['attachments'] = []
|
||||||
|
|
||||||
# TODO: res['embeds']
|
|
||||||
res['embeds'] = []
|
|
||||||
|
|
||||||
# TODO: res['member'] for partial member data
|
# TODO: res['member'] for partial member data
|
||||||
# of the author
|
# of the author
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
from typing import Any
|
||||||
|
from quart.json import JSONEncoder
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
@ -110,3 +113,29 @@ def mmh3(key: str, seed: int = 0):
|
||||||
h1 ^= _u(h1) >> 16
|
h1 ^= _u(h1) >> 16
|
||||||
|
|
||||||
return _u(h1) >> 0
|
return _u(h1) >> 0
|
||||||
|
|
||||||
|
|
||||||
|
class LitecordJSONEncoder(JSONEncoder):
|
||||||
|
def default(self, value: Any):
|
||||||
|
try:
|
||||||
|
return value.to_json
|
||||||
|
except AttributeError:
|
||||||
|
return super().default(value)
|
||||||
|
|
||||||
|
|
||||||
|
async def pg_set_json(con):
|
||||||
|
"""Set JSON and JSONB codecs for an
|
||||||
|
asyncpg connection."""
|
||||||
|
await con.set_type_codec(
|
||||||
|
'json',
|
||||||
|
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||||
|
decoder=json.loads,
|
||||||
|
schema='pg_catalog'
|
||||||
|
)
|
||||||
|
|
||||||
|
await con.set_type_codec(
|
||||||
|
'jsonb',
|
||||||
|
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||||
|
decoder=json.loads,
|
||||||
|
schema='pg_catalog'
|
||||||
|
)
|
||||||
|
|
|
||||||
5
run.py
5
run.py
|
|
@ -50,6 +50,8 @@ from litecord.presence import PresenceManager
|
||||||
from litecord.images import IconManager
|
from litecord.images import IconManager
|
||||||
from litecord.jobs import JobManager
|
from litecord.jobs import JobManager
|
||||||
|
|
||||||
|
from litecord.utils import LitecordJSONEncoder
|
||||||
|
|
||||||
# setup logbook
|
# setup logbook
|
||||||
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
handler = StreamHandler(sys.stdout, level=logbook.INFO)
|
||||||
handler.push_application()
|
handler.push_application()
|
||||||
|
|
@ -71,6 +73,9 @@ def make_app():
|
||||||
# always keep websockets on INFO
|
# always keep websockets on INFO
|
||||||
logging.getLogger('websockets').setLevel(logbook.INFO)
|
logging.getLogger('websockets').setLevel(logbook.INFO)
|
||||||
|
|
||||||
|
# use our custom json encoder for custom data types
|
||||||
|
app.json_encoder = LitecordJSONEncoder
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,83 @@
|
||||||
|
from litecord.schemas import validate
|
||||||
|
from litecord.embed.schemas import EMBED_OBJECT
|
||||||
|
|
||||||
|
def validate_embed(embed):
|
||||||
|
return validate(embed, EMBED_OBJECT)
|
||||||
|
|
||||||
|
def valid(embed: dict):
|
||||||
|
try:
|
||||||
|
validate_embed(embed)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def invalid(embed):
|
||||||
|
try:
|
||||||
|
validate_embed(embed)
|
||||||
|
return False
|
||||||
|
except:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_empty_embed():
|
||||||
|
valid({})
|
||||||
|
|
||||||
|
|
||||||
|
def test_basic_embed():
|
||||||
|
assert valid({
|
||||||
|
'title': 'test',
|
||||||
|
'description': 'acab',
|
||||||
|
'url': 'https://www.w3.org',
|
||||||
|
'color': 123
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def test_footer_embed():
|
||||||
|
assert invalid({
|
||||||
|
'footer': {}
|
||||||
|
})
|
||||||
|
|
||||||
|
assert valid({
|
||||||
|
'title': 'test',
|
||||||
|
'footer': {
|
||||||
|
'text': 'abcdef'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_image():
|
||||||
|
assert invalid({
|
||||||
|
'image': {}
|
||||||
|
})
|
||||||
|
|
||||||
|
assert valid({
|
||||||
|
'image': {
|
||||||
|
'url': 'https://www.w3.org'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_author():
|
||||||
|
assert invalid({
|
||||||
|
'author': {
|
||||||
|
'name': ''
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
assert valid({
|
||||||
|
'author': {
|
||||||
|
'name': 'abcdef'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
def test_fields():
|
||||||
|
assert valid({
|
||||||
|
'fields': [
|
||||||
|
{'name': 'a', 'value': 'b'},
|
||||||
|
{'name': 'c', 'value': 'd', 'inline': False},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
|
valid({
|
||||||
|
'fields': [
|
||||||
|
{'name': 'a'},
|
||||||
|
]
|
||||||
|
})
|
||||||
Loading…
Reference in New Issue