Merge branch 'message-embeds' into 'master'

Message embeds

Closes #17

See merge request litecord/litecord!10
This commit is contained in:
Luna Mendes 2018-12-05 02:12:06 +00:00
commit d2bd6dd342
10 changed files with 429 additions and 96 deletions

View File

@ -10,6 +10,9 @@ from litecord.errors import MessageNotFound, Forbidden, BadRequest
from litecord.enums import MessageType, ChannelType, GUILD_CHANS from litecord.enums import MessageType, ChannelType, GUILD_CHANS
from litecord.snowflake import get_snowflake from litecord.snowflake import get_snowflake
from litecord.schemas import validate, MESSAGE_CREATE from litecord.schemas import validate, MESSAGE_CREATE
from litecord.utils import pg_set_json
from litecord.embed import sanitize_embed
log = Logger(__name__) log = Logger(__name__)
@ -139,8 +142,93 @@ async def _dm_pre_dispatch(channel_id, peer_id):
await try_dm_state(peer_id, channel_id) await try_dm_state(peer_id, channel_id)
async def create_message(channel_id: int, actual_guild_id: int,
author_id: int, data: dict) -> int:
message_id = get_snowflake()
async with app.db.acquire() as conn:
await pg_set_json(conn)
await conn.execute(
"""
INSERT INTO messages (id, channel_id, guild_id, author_id,
content, tts, mention_everyone, nonce, message_type,
embeds)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
""",
message_id,
channel_id,
actual_guild_id,
author_id,
data['content'],
data['tts'],
data['everyone_mention'],
data['nonce'],
MessageType.DEFAULT.value,
data.get('embeds', [])
)
return message_id
async def _guild_text_mentions(payload: dict, guild_id: int,
mentions_everyone: bool, mentions_here: bool):
channel_id = int(payload['channel_id'])
# calculate the user ids we'll bump the mention count for
uids = set()
# first is extracting user mentions
for mention in payload['mentions']:
uids.add(int(mention['id']))
# then role mentions
for role_mention in payload['mention_roles']:
role_id = int(role_mention)
member_ids = await app.storage.get_role_members(role_id)
for member_id in member_ids:
uids.add(member_id)
# at-here only updates the state
# for the users that have a state
# in the channel.
if mentions_here:
uids = []
await app.db.execute("""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1
""", channel_id)
# at-here updates the read state
# for all users, including the ones
# that might not have read permissions
# to the channel.
if mentions_everyone:
uids = []
member_ids = await app.storage.get_member_ids(guild_id)
await app.db.executemany("""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1 AND user_id = $2
""", [(channel_id, uid) for uid in member_ids])
for user_id in uids:
await app.db.execute("""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE user_id = $1
AND channel_id = $2
""", user_id, channel_id)
@bp.route('/<int:channel_id>/messages', methods=['POST']) @bp.route('/<int:channel_id>/messages', methods=['POST'])
async def create_message(channel_id): async def _create_message(channel_id):
"""Create a message."""
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
@ -151,7 +239,6 @@ async def create_message(channel_id):
actual_guild_id = guild_id actual_guild_id = guild_id
j = validate(await request.get_json(), MESSAGE_CREATE) j = validate(await request.get_json(), MESSAGE_CREATE)
message_id = get_snowflake()
# TODO: check connection to the gateway # TODO: check connection to the gateway
@ -167,24 +254,14 @@ async def create_message(channel_id):
user_id, channel_id, 'send_tts_messages', False user_id, channel_id, 'send_tts_messages', False
)) ))
await app.db.execute( message_id = await create_message(
""" channel_id, actual_guild_id, user_id, {
INSERT INTO messages (id, channel_id, guild_id, author_id, 'content': j['content'],
content, tts, mention_everyone, nonce, message_type) 'tts': is_tts,
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) 'nonce': int(j.get('nonce', 0)),
""", 'everyone_mention': mentions_everyone or mentions_here,
message_id, 'embeds': [sanitize_embed(j['embed'])] if 'embed' in j else [],
channel_id, })
actual_guild_id,
user_id,
j['content'],
is_tts,
mentions_everyone or mentions_here,
int(j.get('nonce', 0)),
MessageType.DEFAULT.value
)
payload = await app.storage.get_message(message_id, user_id) payload = await app.storage.get_message(message_id, user_id)
@ -196,6 +273,7 @@ async def create_message(channel_id):
await app.dispatcher.dispatch('channel', channel_id, await app.dispatcher.dispatch('channel', channel_id,
'MESSAGE_CREATE', payload) 'MESSAGE_CREATE', payload)
# update read state for the author
await app.db.execute(""" await app.db.execute("""
UPDATE user_read_state UPDATE user_read_state
SET last_message_id = $1 SET last_message_id = $1
@ -203,54 +281,8 @@ async def create_message(channel_id):
""", message_id, channel_id, user_id) """, message_id, channel_id, user_id)
if ctype == ChannelType.GUILD_TEXT: if ctype == ChannelType.GUILD_TEXT:
# calculate the user ids we'll bump the mention count for await _guild_text_mentions(payload, guild_id,
uids = set() mentions_everyone, mentions_here)
# first is extracting user mentions
for mention in payload['mentions']:
uids.add(int(mention['id']))
# then role mentions
for role_mention in payload['mention_roles']:
role_id = int(role_mention)
member_ids = await app.storage.get_role_members(role_id)
for member_id in member_ids:
uids.add(member_id)
# at-here only updates the state
# for the users that have a state
# in the channel.
if mentions_here:
uids = []
await app.db.execute("""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1
""", channel_id)
# at-here updates the read state
# for all users, including the ones
# that might not have read permissions
# to the channel.
if mentions_everyone:
uids = []
member_ids = await app.storage.get_member_ids(guild_id)
await app.db.executemany("""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE channel_id = $1 AND user_id = $2
""", [(channel_id, uid) for uid in member_ids])
for user_id in uids:
await app.db.execute("""
UPDATE user_read_state
SET mention_count = mention_count + 1
WHERE user_id = $1
AND channel_id = $2
""", user_id, channel_id)
return jsonify(payload) return jsonify(payload)

View File

@ -0,0 +1,3 @@
from .sanitizer import sanitize_embed
__all__ = ['sanitize_embed']

View File

@ -0,0 +1,73 @@
"""
litecord.embed.sanitizer
sanitize embeds by giving common values
such as type: rich
"""
from typing import Dict, Any
from logbook import Logger
log = Logger(__name__)
Embed = Dict[str, Any]
def sanitize_embed(embed: Embed) -> Embed:
"""Sanitize an embed object.
This is non-complex sanitization as it doesn't
need the app object.
"""
return {**embed, **{
'type': 'rich'
}}
def path_exists(embed: Embed, components: str):
"""Tell if a given path exists in an embed (or any dictionary).
The components string is formatted like this:
key1.key2.key3.key4. <...> .keyN
with each key going deeper and deeper into the embed.
"""
# get the list of components given
if isinstance(components, str):
components = components.split('.')
else:
components = list(components)
# if there are no components, we reached the end of recursion
# and can return true
if not components:
return True
# extract current component
current = components[0]
# if it exists, then we go down a level inside the dict
# (via recursion)
if current in embed:
return path_exists(embed[current], components[1:])
# if it doesn't exist, return False
return False
async def fill_embed(embed: Embed) -> Embed:
"""Fill an embed with more information."""
embed = sanitize_embed(embed)
if path_exists(embed, 'footer.icon_url'):
# TODO: make proxy_icon_url
log.warning('embed with footer.icon_url, ignoring')
if path_exists(embed, 'image.url'):
# TODO: make proxy_icon_url, width, height
log.warning('embed with footer.image_url, ignoring')
if path_exists(embed, 'author.icon_url'):
# TODO: should we check icon_url and convert it into
# a proxied icon url?
log.warning('embed with author.icon_url, ignoring')
return embed

121
litecord/embed/schemas.py Normal file
View File

@ -0,0 +1,121 @@
"""
litecord.embed.schemas - embed input validators.
"""
import urllib.parse
from litecord.types import Color
class EmbedURL:
def __init__(self, url: str):
parsed = urllib.parse.urlparse(url)
if parsed.scheme not in ('http', 'https', 'attachment'):
raise ValueError('Invalid URL scheme')
self.raw_url = url
self.parsed = parsed
@property
def url(self):
"""Return the URL."""
return urllib.parse.urlunparse(self.parsed)
@property
def to_json(self):
return self.url
EMBED_FOOTER = {
'text': {
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': True},
'icon_url': {
'coerce': EmbedURL, 'required': False,
},
# NOTE: proxy_icon_url set by us
}
EMBED_IMAGE = {
'url': {'coerce': EmbedURL, 'required': True},
# NOTE: proxy_url, width, height set by us
}
EMBED_THUMBNAIL = EMBED_IMAGE
EMBED_AUTHOR = {
'name': {
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False
},
'url': {
'coerce': EmbedURL, 'required': False,
},
'icon_url': {
'coerce': EmbedURL, 'required': False,
}
}
EMBED_FIELD = {
'name': {
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': True
},
'value': {
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': True
},
'inline': {
'type': 'boolean', 'required': False, 'default': True,
},
}
EMBED_OBJECT = {
'title': {
'type': 'string', 'minlength': 1, 'maxlength': 128, 'required': False},
# NOTE: type set by us
'description': {
'type': 'string', 'minlength': 1, 'maxlength': 1024, 'required': False,
},
'url': {
'coerce': EmbedURL, 'required': False,
},
'timestamp': {
# TODO: an ISO 8601 type
# TODO: maybe replace the default in here with now().isoformat?
'type': 'string', 'required': False
},
'color': {
'coerce': Color, 'required': False
},
'footer': {
'type': 'dict',
'schema': EMBED_FOOTER,
'required': False,
},
'image': {
'type': 'dict',
'schema': EMBED_IMAGE,
'required': False,
},
'thumbnail': {
'type': 'dict',
'schema': EMBED_THUMBNAIL,
'required': False,
},
# NOTE: 'video' set by us
# NOTE: 'provider' set by us
'author': {
'type': 'dict',
'schema': EMBED_AUTHOR,
'required': False,
},
'fields': {
'type': 'list',
'schema': {'type': 'dict', 'schema': EMBED_FIELD},
'required': False,
},
}

View File

@ -13,7 +13,7 @@ import earl
from litecord.auth import raw_token_check from litecord.auth import raw_token_check
from litecord.enums import RelationshipType from litecord.enums import RelationshipType
from litecord.schemas import validate, GW_STATUS_UPDATE from litecord.schemas import validate, GW_STATUS_UPDATE
from litecord.utils import task_wrapper from litecord.utils import task_wrapper, LitecordJSONEncoder
from litecord.permissions import get_permissions from litecord.permissions import get_permissions
from litecord.gateway.opcodes import OP from litecord.gateway.opcodes import OP
@ -39,7 +39,8 @@ WebsocketObjects = collections.namedtuple(
def encode_json(payload) -> str: def encode_json(payload) -> str:
return json.dumps(payload, separators=(',', ':')) return json.dumps(payload, separators=(',', ':'),
cls=LitecordJSONEncoder)
def decode_json(data: str): def decode_json(data: str):

View File

@ -12,6 +12,7 @@ from .enums import (
MessageNotifications, ChannelType, VerificationLevel MessageNotifications, ChannelType, VerificationLevel
) )
from litecord.embed.schemas import EMBED_OBJECT
log = Logger(__name__) log = Logger(__name__)
@ -372,6 +373,12 @@ MESSAGE_CREATE = {
'nonce': {'type': 'snowflake', 'required': False}, 'nonce': {'type': 'snowflake', 'required': False},
'tts': {'type': 'boolean', 'required': False}, 'tts': {'type': 'boolean', 'required': False},
'embed': {
'type': 'dict',
'schema': EMBED_OBJECT,
'required': False
}
# TODO: file, embed, payload_json # TODO: file, embed, payload_json
} }

View File

@ -1,4 +1,3 @@
import json
from typing import List, Dict, Any from typing import List, Dict, Any
from logbook import Logger from logbook import Logger
@ -12,6 +11,7 @@ from litecord.blueprints.channel.reactions import (
from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
from litecord.types import timestamp_ from litecord.types import timestamp_
from litecord.utils import pg_set_json
log = Logger(__name__) log = Logger(__name__)
@ -33,24 +33,6 @@ def str_(val):
return maybe(str, val) return maybe(str, val)
async def _set_json(con):
"""Set JSON and JSONB codecs for an
asyncpg connection."""
await con.set_type_codec(
'json',
encoder=json.dumps,
decoder=json.loads,
schema='pg_catalog'
)
await con.set_type_codec(
'jsonb',
encoder=json.dumps,
decoder=json.loads,
schema='pg_catalog'
)
def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int): def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int):
"""Filter recipients in a list of recipients, removing """Filter recipients in a list of recipients, removing
the one that is reundant (ourselves).""" the one that is reundant (ourselves)."""
@ -73,13 +55,13 @@ class Storage:
# set_type_codec, so we must set it manually # set_type_codec, so we must set it manually
# by acquiring the connection # by acquiring the connection
async with self.db.acquire() as con: async with self.db.acquire() as con:
await _set_json(con) await pg_set_json(con)
return await con.fetchrow(query, *args) return await con.fetchrow(query, *args)
async def fetch_with_json(self, query: str, *args): async def fetch_with_json(self, query: str, *args):
"""Fetch many rows with JSON/JSONB support.""" """Fetch many rows with JSON/JSONB support."""
async with self.db.acquire() as con: async with self.db.acquire() as con:
await _set_json(con) await pg_set_json(con)
return await con.fetch(query, *args) return await con.fetch(query, *args)
async def get_user(self, user_id, secure=False) -> Dict[str, Any]: async def get_user(self, user_id, secure=False) -> Dict[str, Any]:
@ -631,10 +613,10 @@ class Storage:
async def get_message(self, message_id: int, user_id=None) -> Dict: async def get_message(self, message_id: int, user_id=None) -> Dict:
"""Get a single message's payload.""" """Get a single message's payload."""
row = await self.db.fetchrow(""" row = await self.fetchrow_with_json("""
SELECT id::text, channel_id::text, author_id, content, SELECT id::text, channel_id::text, author_id, content,
created_at AS timestamp, edited_at AS edited_timestamp, created_at AS timestamp, edited_at AS edited_timestamp,
tts, mention_everyone, nonce, message_type tts, mention_everyone, nonce, message_type, embeds
FROM messages FROM messages
WHERE id = $1 WHERE id = $1
""", message_id) """, message_id)
@ -698,9 +680,6 @@ class Storage:
# TODO: res['attachments'] # TODO: res['attachments']
res['attachments'] = [] res['attachments'] = []
# TODO: res['embeds']
res['embeds'] = []
# TODO: res['member'] for partial member data # TODO: res['member'] for partial member data
# of the author # of the author

View File

@ -1,5 +1,8 @@
import asyncio import asyncio
import json
from logbook import Logger from logbook import Logger
from typing import Any
from quart.json import JSONEncoder
log = Logger(__name__) log = Logger(__name__)
@ -110,3 +113,29 @@ def mmh3(key: str, seed: int = 0):
h1 ^= _u(h1) >> 16 h1 ^= _u(h1) >> 16
return _u(h1) >> 0 return _u(h1) >> 0
class LitecordJSONEncoder(JSONEncoder):
def default(self, value: Any):
try:
return value.to_json
except AttributeError:
return super().default(value)
async def pg_set_json(con):
"""Set JSON and JSONB codecs for an
asyncpg connection."""
await con.set_type_codec(
'json',
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema='pg_catalog'
)
await con.set_type_codec(
'jsonb',
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema='pg_catalog'
)

5
run.py
View File

@ -50,6 +50,8 @@ from litecord.presence import PresenceManager
from litecord.images import IconManager from litecord.images import IconManager
from litecord.jobs import JobManager from litecord.jobs import JobManager
from litecord.utils import LitecordJSONEncoder
# setup logbook # setup logbook
handler = StreamHandler(sys.stdout, level=logbook.INFO) handler = StreamHandler(sys.stdout, level=logbook.INFO)
handler.push_application() handler.push_application()
@ -71,6 +73,9 @@ def make_app():
# always keep websockets on INFO # always keep websockets on INFO
logging.getLogger('websockets').setLevel(logbook.INFO) logging.getLogger('websockets').setLevel(logbook.INFO)
# use our custom json encoder for custom data types
app.json_encoder = LitecordJSONEncoder
return app return app

83
tests/test_embeds.py Normal file
View File

@ -0,0 +1,83 @@
from litecord.schemas import validate
from litecord.embed.schemas import EMBED_OBJECT
def validate_embed(embed):
return validate(embed, EMBED_OBJECT)
def valid(embed: dict):
try:
validate_embed(embed)
return True
except:
return False
def invalid(embed):
try:
validate_embed(embed)
return False
except:
return True
def test_empty_embed():
valid({})
def test_basic_embed():
assert valid({
'title': 'test',
'description': 'acab',
'url': 'https://www.w3.org',
'color': 123
})
def test_footer_embed():
assert invalid({
'footer': {}
})
assert valid({
'title': 'test',
'footer': {
'text': 'abcdef'
}
})
def test_image():
assert invalid({
'image': {}
})
assert valid({
'image': {
'url': 'https://www.w3.org'
}
})
def test_author():
assert invalid({
'author': {
'name': ''
}
})
assert valid({
'author': {
'name': 'abcdef'
}
})
def test_fields():
assert valid({
'fields': [
{'name': 'a', 'value': 'b'},
{'name': 'c', 'value': 'd', 'inline': False},
]
})
valid({
'fields': [
{'name': 'a'},
]
})