blueprints.channels: add message sending

It is clunky when sending messages as Atomic, nor we have proper channel
management, but it works.

 - dispatcher: add sub_many
 - errors: add failsafe on LitecordError.message
 - errors: add ChannelNotFound
 - gateway.websocket: add dispatcher to WebsocketObjects
 - schemas: add regexes for mentions
 - storage: add get_channel, get_message
This commit is contained in:
Luna Mendes 2018-07-03 01:02:26 -03:00
parent 4ea3d353b3
commit 59127ad197
8 changed files with 238 additions and 46 deletions

View File

@ -1,39 +1,50 @@
import time
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
from ..auth import token_check
from ..snowflake import get_snowflake
from ..enums import ChannelType
from ..errors import Forbidden, BadRequest, MessageNotFound
from ..schemas import validate
from ..enums import ChannelType, MessageType
from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound
from ..schemas import validate, MESSAGE_CREATE
from .guilds import guild_check
log = Logger(__name__)
bp = Blueprint('channels', __name__)
async def channel_check(user_id, channel_id):
ctype = await app.db.fetchval("""
SELECT channel_type
FROM channels
WHERE channels.id = $1
""", channel_id)
"""Check if the current user is authorized
to read the channel's information."""
ctype = await app.storage.get_chan_type(channel_id)
if ctype is None:
raise ChannelNotFound(f'channel type not found')
if ctype in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY):
guild_id = await app.db.fetchval("""
SELECT guild_id
FROM guild_channels
WHERE channel_id = $1
WHERE guild_channels.id = $1
""", channel_id)
await guild_check(user_id, guild_id)
return guild_id
@bp.route('/<int:channel_id>', methods=['GET'])
async def get_channel(channel_id):
user_id = await token_check()
await channel_check(user_id, channel_id)
return '', 204
chan = await app.storage.get_channel(channel_id)
if not chan:
raise ChannelNotFound('single channel not found')
return jsonify(chan)
@bp.route('/<int:channel_id>/messages', methods=['GET'])
@ -43,13 +54,27 @@ async def get_messages(channel_id):
# TODO: before, after, around keys
await app.db.fetch(f"""
SELECT *
message_ids = await app.db.fetch(f"""
SELECT id
FROM messages
WHERE channel_id = $1
ORDER BY id ASC
ORDER BY id DESC
LIMIT 100
""")
""", channel_id)
result = []
for message_id in message_ids:
msg = await app.storage.get_message(message_id['id'])
if msg is None:
continue
result.append(msg)
log.info('Fetched {} messages', len(result))
print(result)
return jsonify(result)
@bp.route('/<int:channel_id>/messages/<int:message_id>', methods=['GET'])
@ -59,25 +84,50 @@ async def get_single_message(channel_id, message_id):
# TODO: check READ_MESSAGE_HISTORY permissions
message = await app.db.fetchrow("""
SELECT *
FROM messages
WHERE channel_id = $1 AND messages.id = $2
""", channel_id, message_id)
message = await app.storage.get_message(message_id)
if not message:
raise MessageNotFound()
return jsonify(message)
@bp.route('/<int:channel_id>/messages', methods=['POST'])
async def create_message(channel_id):
user_id = await token_check()
await channel_check(user_id, channel_id)
guild_id = await channel_check(user_id, channel_id)
j = validate(await request.get_json(), MESSAGE_CREATE)
message_id = get_snowflake()
# TODO: check SEND_MESSAGES permission
# TODO: check SEND_TTS_MESSAGES
# TODO: check connection to the gateway
await app.db.execute(
"""
INSERT INTO messages (id, channel_id, author_id, content, tts,
mention_everyone, nonce, message_type)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
""", message_id, channel_id, user_id, j['content'], j.get('tts', False),
'@everyone' in j['content'], j.get('nonce', 0), MessageType.DEFAULT)
# TODO: parse payload, make schema
# TODO: insert and dispatch message
# TODO: dispatch_channel
payload = await app.storage.get_message(message_id)
await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload)
return jsonify(payload)
@bp.route('/<int:channel_id>/typing', methods=['POST'])
async def trigger_typing(channel_id):
user_id = await token_check()
guild_id = await channel_check(user_id, channel_id)
await app.dispatcher.dispatch_guild(guild_id, 'TYPING_START', {
'channel_id': channel_id,
'user_id': user_id,
'timestamp': int(time.time()),
})
return '', 204

View File

@ -18,7 +18,7 @@ async def guild_check(user_id: int, guild_id: int):
""", user_id, guild_id)
if not joined_at:
raise GuildNotFound()
raise GuildNotFound('guild not found')
async def guild_owner_check(user_id: int, guild_id: int):

View File

@ -24,6 +24,11 @@ class EventDispatcher:
"""Reset the guild bucket."""
self.guild_buckets[guild_id] = set()
def sub_many(self, user_id: int, guild_ids: list):
"""Subscribe to many guilds at a time."""
for guild_id in guild_ids:
self.sub_guild(guild_id, user_id)
async def dispatch_guild(self, guild_id: int,
event_name: str, event_payload: Any):
"""Dispatch an event to a guild"""

View File

@ -3,7 +3,10 @@ class LitecordError(Exception):
@property
def message(self):
return self.args[0]
try:
return self.args[0]
except IndexError:
return repr(self)
@property
def json(self):
@ -30,6 +33,10 @@ class GuildNotFound(LitecordError):
status_code = 404
class ChannelNotFound(LitecordError):
status_code = 404
class MessageNotFound(LitecordError):
status_code = 404

View File

@ -22,7 +22,7 @@ WebsocketProperties = collections.namedtuple(
)
WebsocketObjects = collections.namedtuple(
'WebsocketObjects', 'db state_manager storage loop'
'WebsocketObjects', 'db state_manager storage loop dispatcher'
)
@ -81,7 +81,6 @@ class GatewayWebsocket:
if not isinstance(encoded, bytes):
encoded = encoded.encode()
print(self.wsp.compress)
if self.wsp.compress == 'zlib-stream':
data1 = self.wsp.zctx.compress(encoded)
data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
@ -246,6 +245,17 @@ class GatewayWebsocket:
if current_shard > shard_count:
raise InvalidShard('Shard count > Total shards')
async def subscribe_guilds(self):
"""Subscribe to all available guilds"""
guild_ids = await self.ext.db.fetch("""
SELECT guild_id
FROM members
WHERE user_id = $1
""", self.state.user_id)
guild_ids = [r['guild_id'] for r in guild_ids]
self.ext.dispatcher.sub_many(self.state.user_id, guild_ids)
async def handle_1(self, payload: Dict[str, Any]):
"""Handle OP 1 Heartbeat packets."""
pass
@ -291,6 +301,7 @@ class GatewayWebsocket:
self.ext.state_manager.insert(self.state)
await self.dispatch_ready()
await self.subscribe_guilds()
async def handle_3(self, payload: Dict[str, Any]):
"""Handle OP 3 Status Update."""
@ -386,6 +397,9 @@ class GatewayWebsocket:
log.warning('closed a client, state={} err={}', self.state, err)
await self.ws.close(code=err.code, reason=err.reason)
except Exception as err:
log.exception('An exception has occoured. state={}', self.state)
await self.ws.close(code=4000, reason=repr(err))
finally:
if self.state:
self.state.ws = None

View File

@ -8,6 +8,15 @@ USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A)
EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$',
re.A)
# collection of regexes
USER_MENTION = re.compile(r'<@!?(\d+)>', re.A | re.M)
CHAN_MENTION = re.compile(r'<#(\d+)>', re.A | re.M)
ROLE_MENTION = re.compile(r'<@&(\d+)>', re.A | re.M)
EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M)
ANIMOJI_MENTION = re.compile(r'<a:(\.+):(\d+)>', re.A | re.M)
class LitecordValidator(Validator):
def _validate_type_username(self, value: str) -> bool:
"""Validate against the username regex."""
@ -61,3 +70,11 @@ MEMBER_UPDATE = {
'deaf': {'type': 'bool', 'required': False},
'channel_id': {'type': 'snowflake', 'required': False},
}
MESSAGE_CREATE = {
'content': {'type': 'string', 'minlength': 1, 'maxlength': 2000},
'nonce': {'type': 'number', 'required': False},
'tts': {'type': 'boolean', 'required': False},
# TODO: file, embed, payload_json
}

View File

@ -1,6 +1,11 @@
from typing import List, Dict, Any
from .enums import ChannelType
from .schemas import USER_MENTION, ROLE_MENTION
async def _dummy(any_id):
return str(any_id)
class Storage:
@ -26,8 +31,8 @@ class Storage:
if not secure:
duser.pop('email')
duser.pop('mfa_enabled')
duser.pop('verified')
duser.pop('mfa_enabled')
return duser
@ -174,6 +179,51 @@ class Storage:
return {**row, **dict(vrow)}
async def get_chan_type(self, channel_id) -> int:
return await self.db.fetchval("""
SELECT channel_type
FROM channels
WHERE channels.id = $1
""", channel_id)
async def _chan_overwrites(self, channel_id):
overwrite_rows = await self.db.fetch("""
SELECT target_id::text AS id, overwrite_type, allow, deny
FROM channel_overwrites
WHERE channel_id = $1
""", channel_id)
def _overwrite_convert(ov_row):
drow = dict(ov_row)
drow['type'] = drow['overwrite_type']
drow.pop('overwrite_type')
return drow
return map(_overwrite_convert, overwrite_rows)
async def get_channel(self, channel_id) -> Dict[str, Any]:
"""Fetch a single channel's information."""
chan_type = await self.get_chan_type(channel_id)
if chan_type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY):
base = await self.db.fetchrow("""
SELECT id, guild_id::text, parent_id, name, position, nsfw
FROM guild_channels
WHERE guild_channels.id = $1
""", channel_id)
res = await self._channels_extra(dict(base), chan_type)
res['type'] = chan_type
res['permission_overwrites'] = \
list(await self._chan_overwrites(channel_id))
res['id'] = str(res['id'])
return res
else:
# TODO: dms and group dms
pass
async def get_channel_data(self, guild_id) -> List[Dict]:
"""Get channel information on a guild"""
channel_basics = await self.db.fetch("""
@ -193,22 +243,8 @@ class Storage:
res = await self._channels_extra(dict(row), ctype)
res['type'] = ctype
# type is a SQL keyword, so we can't do
# 'overwrite_type AS type'
overwrite_rows = await self.db.fetch("""
SELECT target_id::text AS id, overwrite_type, allow, deny
FROM channel_overwrites
WHERE channel_id = $1
""", row['id'])
def _overwrite_convert(ov_row):
drow = dict(ov_row)
drow['type'] = drow['overwrite_type']
drow.pop('overwrite_type')
return drow
res['permission_overwrites'] = list(map(_overwrite_convert,
overwrite_rows))
res['permission_overwrites'] = \
list(await self._chan_overwrites(row['id']))
# Making sure.
res['id'] = str(res['id'])
@ -262,6 +298,69 @@ class Storage:
'voice_states': [],
'channels': channels,
'roles': roles,
# TODO: finish those
# TODO: finish presences
'presences': [],
}}
async def _msg_regex(self, regex, method, content) -> List[Dict]:
res = []
for match in regex.finditer(content):
found_id = match.group(1)
try:
found_id = int(found_id)
except ValueError:
continue
obj = await method(found_id)
if obj:
res.append(obj)
return res
async def get_message(self, message_id: int) -> Dict:
"""Get a single message's payload."""
row = await self.db.fetchrow("""
SELECT id::text, channel_id::text, author_id, content,
created_at AS timestamp, edited_at AS edited_timestamp,
tts, mention_everyone, nonce, message_type
FROM messages
WHERE id = $1
""", message_id)
if not row:
return
res = dict(row)
res['timestamp'] = res['timestamp'].isoformat()
res['type'] = res['message_type']
res.pop('message_type')
# calculate user mentions and role mentions by regex
res['mentions'] = await self._msg_regex(USER_MENTION, self.get_user,
row['content'])
# _dummy just returns the string of the id, since we don't
# actually use the role objects in mention_roles, just their ids.
res['mention_roles'] = await self._msg_regex(ROLE_MENTION, _dummy,
row['content'])
# TODO: handle webhook authors
res['author'] = await self.get_user(res['author_id'])
res.pop('author_id')
# TODO: res['attachments']
res['attachments'] = []
# TODO: res['embeds']
res['embeds'] = []
# TODO: res['reactions']
res['reactions'] = []
# TODO: res['pinned']
res['pinned'] = False
return res

4
run.py
View File

@ -63,8 +63,8 @@ async def app_before_serving():
async def _wrapper(ws, url):
# We wrap the main websocket_handler
# so we can pass quart's app object.
await websocket_handler((app.db, app.state_manager,
app.storage, app.loop), ws, url)
await websocket_handler((app.db, app.state_manager, app.storage,
app.loop, app.dispatcher), ws, url)
ws_future = websockets.serve(_wrapper, host, port)