blueprints.channels: add message sending

It is clunky when sending messages as Atomic, nor we have proper channel
management, but it works.

 - dispatcher: add sub_many
 - errors: add failsafe on LitecordError.message
 - errors: add ChannelNotFound
 - gateway.websocket: add dispatcher to WebsocketObjects
 - schemas: add regexes for mentions
 - storage: add get_channel, get_message
This commit is contained in:
Luna Mendes 2018-07-03 01:02:26 -03:00
parent 4ea3d353b3
commit 59127ad197
8 changed files with 238 additions and 46 deletions

View File

@ -1,39 +1,50 @@
import time
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
from ..auth import token_check from ..auth import token_check
from ..snowflake import get_snowflake from ..snowflake import get_snowflake
from ..enums import ChannelType from ..enums import ChannelType, MessageType
from ..errors import Forbidden, BadRequest, MessageNotFound from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound
from ..schemas import validate from ..schemas import validate, MESSAGE_CREATE
from .guilds import guild_check from .guilds import guild_check
log = Logger(__name__)
bp = Blueprint('channels', __name__) bp = Blueprint('channels', __name__)
async def channel_check(user_id, channel_id): async def channel_check(user_id, channel_id):
ctype = await app.db.fetchval(""" """Check if the current user is authorized
SELECT channel_type to read the channel's information."""
FROM channels ctype = await app.storage.get_chan_type(channel_id)
WHERE channels.id = $1
""", channel_id) if ctype is None:
raise ChannelNotFound(f'channel type not found')
if ctype in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE, if ctype in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY): ChannelType.GUILD_CATEGORY):
guild_id = await app.db.fetchval(""" guild_id = await app.db.fetchval("""
SELECT guild_id SELECT guild_id
FROM guild_channels FROM guild_channels
WHERE channel_id = $1 WHERE guild_channels.id = $1
""", channel_id) """, channel_id)
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
return guild_id
@bp.route('/<int:channel_id>', methods=['GET']) @bp.route('/<int:channel_id>', methods=['GET'])
async def get_channel(channel_id): async def get_channel(channel_id):
user_id = await token_check() user_id = await token_check()
await channel_check(user_id, channel_id) await channel_check(user_id, channel_id)
return '', 204 chan = await app.storage.get_channel(channel_id)
if not chan:
raise ChannelNotFound('single channel not found')
return jsonify(chan)
@bp.route('/<int:channel_id>/messages', methods=['GET']) @bp.route('/<int:channel_id>/messages', methods=['GET'])
@ -43,13 +54,27 @@ async def get_messages(channel_id):
# TODO: before, after, around keys # TODO: before, after, around keys
await app.db.fetch(f""" message_ids = await app.db.fetch(f"""
SELECT * SELECT id
FROM messages FROM messages
WHERE channel_id = $1 WHERE channel_id = $1
ORDER BY id ASC ORDER BY id DESC
LIMIT 100 LIMIT 100
""") """, channel_id)
result = []
for message_id in message_ids:
msg = await app.storage.get_message(message_id['id'])
if msg is None:
continue
result.append(msg)
log.info('Fetched {} messages', len(result))
print(result)
return jsonify(result)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET']) @bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET'])
@ -59,25 +84,50 @@ async def get_single_message(channel_id, message_id):
# TODO: check READ_MESSAGE_HISTORY permissions # TODO: check READ_MESSAGE_HISTORY permissions
message = await app.db.fetchrow(""" message = await app.storage.get_message(message_id)
SELECT *
FROM messages
WHERE channel_id = $1 AND messages.id = $2
""", channel_id, message_id)
if not message: if not message:
raise MessageNotFound() raise MessageNotFound()
return jsonify(message)
@bp.route('/<int:channel_id>/messages', methods=['POST']) @bp.route('/<int:channel_id>/messages', methods=['POST'])
async def create_message(channel_id): async def create_message(channel_id):
user_id = await token_check() user_id = await token_check()
await channel_check(user_id, channel_id) guild_id = await channel_check(user_id, channel_id)
j = validate(await request.get_json(), MESSAGE_CREATE)
message_id = get_snowflake()
# TODO: check SEND_MESSAGES permission # TODO: check SEND_MESSAGES permission
# TODO: check SEND_TTS_MESSAGES # TODO: check SEND_TTS_MESSAGES
# TODO: check connection to the gateway # TODO: check connection to the gateway
await app.db.execute(
"""
INSERT INTO messages (id, channel_id, author_id, content, tts,
mention_everyone, nonce, message_type)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
""", message_id, channel_id, user_id, j['content'], j.get('tts', False),
'@everyone' in j['content'], j.get('nonce', 0), MessageType.DEFAULT)
# TODO: parse payload, make schema # TODO: dispatch_channel
# TODO: insert and dispatch message payload = await app.storage.get_message(message_id)
await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload)
return jsonify(payload)
@bp.route('/<int:channel_id>/typing', methods=['POST'])
async def trigger_typing(channel_id):
user_id = await token_check()
guild_id = await channel_check(user_id, channel_id)
await app.dispatcher.dispatch_guild(guild_id, 'TYPING_START', {
'channel_id': channel_id,
'user_id': user_id,
'timestamp': int(time.time()),
})
return '', 204

View File

@ -18,7 +18,7 @@ async def guild_check(user_id: int, guild_id: int):
""", user_id, guild_id) """, user_id, guild_id)
if not joined_at: if not joined_at:
raise GuildNotFound() raise GuildNotFound('guild not found')
async def guild_owner_check(user_id: int, guild_id: int): async def guild_owner_check(user_id: int, guild_id: int):

View File

@ -24,6 +24,11 @@ class EventDispatcher:
"""Reset the guild bucket.""" """Reset the guild bucket."""
self.guild_buckets[guild_id] = set() self.guild_buckets[guild_id] = set()
def sub_many(self, user_id: int, guild_ids: list):
"""Subscribe to many guilds at a time."""
for guild_id in guild_ids:
self.sub_guild(guild_id, user_id)
async def dispatch_guild(self, guild_id: int, async def dispatch_guild(self, guild_id: int,
event_name: str, event_payload: Any): event_name: str, event_payload: Any):
"""Dispatch an event to a guild""" """Dispatch an event to a guild"""

View File

@ -3,7 +3,10 @@ class LitecordError(Exception):
@property @property
def message(self): def message(self):
return self.args[0] try:
return self.args[0]
except IndexError:
return repr(self)
@property @property
def json(self): def json(self):
@ -30,6 +33,10 @@ class GuildNotFound(LitecordError):
status_code = 404 status_code = 404
class ChannelNotFound(LitecordError):
status_code = 404
class MessageNotFound(LitecordError): class MessageNotFound(LitecordError):
status_code = 404 status_code = 404

View File

@ -22,7 +22,7 @@ WebsocketProperties = collections.namedtuple(
) )
WebsocketObjects = collections.namedtuple( WebsocketObjects = collections.namedtuple(
'WebsocketObjects', 'db state_manager storage loop' 'WebsocketObjects', 'db state_manager storage loop dispatcher'
) )
@ -81,7 +81,6 @@ class GatewayWebsocket:
if not isinstance(encoded, bytes): if not isinstance(encoded, bytes):
encoded = encoded.encode() encoded = encoded.encode()
print(self.wsp.compress)
if self.wsp.compress == 'zlib-stream': if self.wsp.compress == 'zlib-stream':
data1 = self.wsp.zctx.compress(encoded) data1 = self.wsp.zctx.compress(encoded)
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH) data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
@ -246,6 +245,17 @@ class GatewayWebsocket:
if current_shard > shard_count: if current_shard > shard_count:
raise InvalidShard('Shard count > Total shards') raise InvalidShard('Shard count > Total shards')
async def subscribe_guilds(self):
"""Subscribe to all available guilds"""
guild_ids = await self.ext.db.fetch("""
SELECT guild_id
FROM members
WHERE user_id = $1
""", self.state.user_id)
guild_ids = [r['guild_id'] for r in guild_ids]
self.ext.dispatcher.sub_many(self.state.user_id, guild_ids)
async def handle_1(self, payload: Dict[str, Any]): async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets.""" """Handle OP 1 Heartbeat packets."""
pass pass
@ -291,6 +301,7 @@ class GatewayWebsocket:
self.ext.state_manager.insert(self.state) self.ext.state_manager.insert(self.state)
await self.dispatch_ready() await self.dispatch_ready()
await self.subscribe_guilds()
async def handle_3(self, payload: Dict[str, Any]): async def handle_3(self, payload: Dict[str, Any]):
"""Handle OP 3 Status Update.""" """Handle OP 3 Status Update."""
@ -386,6 +397,9 @@ class GatewayWebsocket:
log.warning('closed a client, state={} err={}', self.state, err) log.warning('closed a client, state={} err={}', self.state, err)
await self.ws.close(code=err.code, reason=err.reason) await self.ws.close(code=err.code, reason=err.reason)
except Exception as err:
log.exception('An exception has occoured. state={}', self.state)
await self.ws.close(code=4000, reason=repr(err))
finally: finally:
if self.state: if self.state:
self.state.ws = None self.state.ws = None

View File

@ -8,6 +8,15 @@ USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A)
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$', EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
re.A) re.A)
# collection of regexes
USER_MENTION = re.compile(r'<@!?(\d+)>', re.A | re.M)
CHAN_MENTION = re.compile(r'<#(\d+)>', re.A | re.M)
ROLE_MENTION = re.compile(r'<@&(\d+)>', re.A | re.M)
EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M)
ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M)
class LitecordValidator(Validator): class LitecordValidator(Validator):
def _validate_type_username(self, value: str) -> bool: def _validate_type_username(self, value: str) -> bool:
"""Validate against the username regex.""" """Validate against the username regex."""
@ -61,3 +70,11 @@ MEMBER_UPDATE = {
'deaf': {'type': 'bool', 'required': False}, 'deaf': {'type': 'bool', 'required': False},
'channel_id': {'type': 'snowflake', 'required': False}, 'channel_id': {'type': 'snowflake', 'required': False},
} }
MESSAGE_CREATE = {
'content': {'type': 'string', 'minlength': 1, 'maxlength': 2000},
'nonce': {'type': 'number', 'required': False},
'tts': {'type': 'boolean', 'required': False},
# TODO: file, embed, payload_json
}

View File

@ -1,6 +1,11 @@
from typing import List, Dict, Any from typing import List, Dict, Any
from .enums import ChannelType from .enums import ChannelType
from .schemas import USER_MENTION, ROLE_MENTION
async def _dummy(any_id):
return str(any_id)
class Storage: class Storage:
@ -26,8 +31,8 @@ class Storage:
if not secure: if not secure:
duser.pop('email') duser.pop('email')
duser.pop('mfa_enabled')
duser.pop('verified') duser.pop('verified')
duser.pop('mfa_enabled')
return duser return duser
@ -174,6 +179,51 @@ class Storage:
return {**row, **dict(vrow)} return {**row, **dict(vrow)}
async def get_chan_type(self, channel_id) -> int:
return await self.db.fetchval("""
SELECT channel_type
FROM channels
WHERE channels.id = $1
""", channel_id)
async def _chan_overwrites(self, channel_id):
overwrite_rows = await self.db.fetch("""
SELECT target_id::text AS id, overwrite_type, allow, deny
FROM channel_overwrites
WHERE channel_id = $1
""", channel_id)
def _overwrite_convert(ov_row):
drow = dict(ov_row)
drow['type'] = drow['overwrite_type']
drow.pop('overwrite_type')
return drow
return map(_overwrite_convert, overwrite_rows)
async def get_channel(self, channel_id) -> Dict[str, Any]:
"""Fetch a single channel's information."""
chan_type = await self.get_chan_type(channel_id)
if chan_type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY):
base = await self.db.fetchrow("""
SELECT id, guild_id::text, parent_id, name, position, nsfw
FROM guild_channels
WHERE guild_channels.id = $1
""", channel_id)
res = await self._channels_extra(dict(base), chan_type)
res['type'] = chan_type
res['permission_overwrites'] = \
list(await self._chan_overwrites(channel_id))
res['id'] = str(res['id'])
return res
else:
# TODO: dms and group dms
pass
async def get_channel_data(self, guild_id) -> List[Dict]: async def get_channel_data(self, guild_id) -> List[Dict]:
"""Get channel information on a guild""" """Get channel information on a guild"""
channel_basics = await self.db.fetch(""" channel_basics = await self.db.fetch("""
@ -193,22 +243,8 @@ class Storage:
res = await self._channels_extra(dict(row), ctype) res = await self._channels_extra(dict(row), ctype)
res['type'] = ctype res['type'] = ctype
# type is a SQL keyword, so we can't do res['permission_overwrites'] = \
# 'overwrite_type AS type' list(await self._chan_overwrites(row['id']))
overwrite_rows = await self.db.fetch("""
SELECT target_id::text AS id, overwrite_type, allow, deny
FROM channel_overwrites
WHERE channel_id = $1
""", row['id'])
def _overwrite_convert(ov_row):
drow = dict(ov_row)
drow['type'] = drow['overwrite_type']
drow.pop('overwrite_type')
return drow
res['permission_overwrites'] = list(map(_overwrite_convert,
overwrite_rows))
# Making sure. # Making sure.
res['id'] = str(res['id']) res['id'] = str(res['id'])
@ -262,6 +298,69 @@ class Storage:
'voice_states': [], 'voice_states': [],
'channels': channels, 'channels': channels,
'roles': roles, 'roles': roles,
# TODO: finish those
# TODO: finish presences
'presences': [], 'presences': [],
}} }}
async def _msg_regex(self, regex, method, content) -> List[Dict]:
res = []
for match in regex.finditer(content):
found_id = match.group(1)
try:
found_id = int(found_id)
except ValueError:
continue
obj = await method(found_id)
if obj:
res.append(obj)
return res
async def get_message(self, message_id: int) -> Dict:
"""Get a single message's payload."""
row = await self.db.fetchrow("""
SELECT id::text, channel_id::text, author_id, content,
created_at AS timestamp, edited_at AS edited_timestamp,
tts, mention_everyone, nonce, message_type
FROM messages
WHERE id = $1
""", message_id)
if not row:
return
res = dict(row)
res['timestamp'] = res['timestamp'].isoformat()
res['type'] = res['message_type']
res.pop('message_type')
# calculate user mentions and role mentions by regex
res['mentions'] = await self._msg_regex(USER_MENTION, self.get_user,
row['content'])
# _dummy just returns the string of the id, since we don't
# actually use the role objects in mention_roles, just their ids.
res['mention_roles'] = await self._msg_regex(ROLE_MENTION, _dummy,
row['content'])
# TODO: handle webhook authors
res['author'] = await self.get_user(res['author_id'])
res.pop('author_id')
# TODO: res['attachments']
res['attachments'] = []
# TODO: res['embeds']
res['embeds'] = []
# TODO: res['reactions']
res['reactions'] = []
# TODO: res['pinned']
res['pinned'] = False
return res

4
run.py
View File

@ -63,8 +63,8 @@ async def app_before_serving():
async def _wrapper(ws, url): async def _wrapper(ws, url):
# We wrap the main websocket_handler # We wrap the main websocket_handler
# so we can pass quart's app object. # so we can pass quart's app object.
await websocket_handler((app.db, app.state_manager, await websocket_handler((app.db, app.state_manager, app.storage,
app.storage, app.loop), ws, url) app.loop, app.dispatcher), ws, url)
ws_future = websockets.serve(_wrapper, host, port) ws_future = websockets.serve(_wrapper, host, port)