diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index e1f2ea4..86ef572 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -1,39 +1,50 @@ +import time + from quart import Blueprint, request, current_app as app, jsonify +from logbook import Logger from ..auth import token_check from ..snowflake import get_snowflake -from ..enums import ChannelType -from ..errors import Forbidden, BadRequest, MessageNotFound -from ..schemas import validate +from ..enums import ChannelType, MessageType +from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound +from ..schemas import validate, MESSAGE_CREATE from .guilds import guild_check +log = Logger(__name__) bp = Blueprint('channels', __name__) async def channel_check(user_id, channel_id): - ctype = await app.db.fetchval(""" - SELECT channel_type - FROM channels - WHERE channels.id = $1 - """, channel_id) + """Check if the current user is authorized + to read the channel's information.""" + ctype = await app.storage.get_chan_type(channel_id) + + if ctype is None: + raise ChannelNotFound(f'channel type not found') if ctype in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE, ChannelType.GUILD_CATEGORY): guild_id = await app.db.fetchval(""" SELECT guild_id FROM guild_channels - WHERE channel_id = $1 + WHERE guild_channels.id = $1 """, channel_id) await guild_check(user_id, guild_id) + return guild_id @bp.route('/', methods=['GET']) async def get_channel(channel_id): user_id = await token_check() await channel_check(user_id, channel_id) - return '', 204 + chan = await app.storage.get_channel(channel_id) + + if not chan: + raise ChannelNotFound('single channel not found') + + return jsonify(chan) @bp.route('//messages', methods=['GET']) @@ -43,13 +54,27 @@ async def get_messages(channel_id): # TODO: before, after, around keys - await app.db.fetch(f""" - SELECT * + message_ids = await app.db.fetch(f""" + SELECT id FROM messages WHERE channel_id = $1 - ORDER BY id ASC + ORDER BY id DESC LIMIT 100 - """) + """, channel_id) + + result = [] + + for message_id in message_ids: + msg = await app.storage.get_message(message_id['id']) + + if msg is None: + continue + + result.append(msg) + + log.info('Fetched {} messages', len(result)) + print(result) + return jsonify(result) @bp.route('//messages/', methods=['GET']) @@ -59,25 +84,50 @@ async def get_single_message(channel_id, message_id): # TODO: check READ_MESSAGE_HISTORY permissions - message = await app.db.fetchrow(""" - SELECT * - FROM messages - WHERE channel_id = $1 AND messages.id = $2 - """, channel_id, message_id) + message = await app.storage.get_message(message_id) if not message: raise MessageNotFound() + return jsonify(message) + @bp.route('//messages', methods=['POST']) async def create_message(channel_id): user_id = await token_check() - await channel_check(user_id, channel_id) + guild_id = await channel_check(user_id, channel_id) + + j = validate(await request.get_json(), MESSAGE_CREATE) + message_id = get_snowflake() # TODO: check SEND_MESSAGES permission # TODO: check SEND_TTS_MESSAGES # TODO: check connection to the gateway + await app.db.execute( + """ + INSERT INTO messages (id, channel_id, author_id, content, tts, + mention_everyone, nonce, message_type) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + """, message_id, channel_id, user_id, j['content'], j.get('tts', False), + '@everyone' in j['content'], j.get('nonce', 0), MessageType.DEFAULT) - # TODO: parse payload, make schema - # TODO: insert and dispatch message + # TODO: dispatch_channel + payload = await app.storage.get_message(message_id) + await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload) + + return jsonify(payload) + + +@bp.route('//typing', methods=['POST']) +async def trigger_typing(channel_id): + user_id = await token_check() + guild_id = await channel_check(user_id, channel_id) + + await app.dispatcher.dispatch_guild(guild_id, 'TYPING_START', { + 'channel_id': channel_id, + 'user_id': user_id, + 'timestamp': int(time.time()), + }) + + return '', 204 diff --git a/litecord/blueprints/guilds.py b/litecord/blueprints/guilds.py index f9bea6f..4e01604 100644 --- a/litecord/blueprints/guilds.py +++ b/litecord/blueprints/guilds.py @@ -18,7 +18,7 @@ async def guild_check(user_id: int, guild_id: int): """, user_id, guild_id) if not joined_at: - raise GuildNotFound() + raise GuildNotFound('guild not found') async def guild_owner_check(user_id: int, guild_id: int): diff --git a/litecord/dispatcher.py b/litecord/dispatcher.py index 200ca47..e8f523e 100644 --- a/litecord/dispatcher.py +++ b/litecord/dispatcher.py @@ -24,6 +24,11 @@ class EventDispatcher: """Reset the guild bucket.""" self.guild_buckets[guild_id] = set() + def sub_many(self, user_id: int, guild_ids: list): + """Subscribe to many guilds at a time.""" + for guild_id in guild_ids: + self.sub_guild(guild_id, user_id) + async def dispatch_guild(self, guild_id: int, event_name: str, event_payload: Any): """Dispatch an event to a guild""" diff --git a/litecord/errors.py b/litecord/errors.py index 99a1a9d..fe4f130 100644 --- a/litecord/errors.py +++ b/litecord/errors.py @@ -3,7 +3,10 @@ class LitecordError(Exception): @property def message(self): - return self.args[0] + try: + return self.args[0] + except IndexError: + return repr(self) @property def json(self): @@ -30,6 +33,10 @@ class GuildNotFound(LitecordError): status_code = 404 +class ChannelNotFound(LitecordError): + status_code = 404 + + class MessageNotFound(LitecordError): status_code = 404 diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 9e22b6e..c849c54 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -22,7 +22,7 @@ WebsocketProperties = collections.namedtuple( ) WebsocketObjects = collections.namedtuple( - 'WebsocketObjects', 'db state_manager storage loop' + 'WebsocketObjects', 'db state_manager storage loop dispatcher' ) @@ -81,7 +81,6 @@ class GatewayWebsocket: if not isinstance(encoded, bytes): encoded = encoded.encode() - print(self.wsp.compress) if self.wsp.compress == 'zlib-stream': data1 = self.wsp.zctx.compress(encoded) data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH) @@ -246,6 +245,17 @@ class GatewayWebsocket: if current_shard > shard_count: raise InvalidShard('Shard count > Total shards') + async def subscribe_guilds(self): + """Subscribe to all available guilds""" + guild_ids = await self.ext.db.fetch(""" + SELECT guild_id + FROM members + WHERE user_id = $1 + """, self.state.user_id) + + guild_ids = [r['guild_id'] for r in guild_ids] + self.ext.dispatcher.sub_many(self.state.user_id, guild_ids) + async def handle_1(self, payload: Dict[str, Any]): """Handle OP 1 Heartbeat packets.""" pass @@ -291,6 +301,7 @@ class GatewayWebsocket: self.ext.state_manager.insert(self.state) await self.dispatch_ready() + await self.subscribe_guilds() async def handle_3(self, payload: Dict[str, Any]): """Handle OP 3 Status Update.""" @@ -386,6 +397,9 @@ class GatewayWebsocket: log.warning('closed a client, state={} err={}', self.state, err) await self.ws.close(code=err.code, reason=err.reason) + except Exception as err: + log.exception('An exception has occoured. state={}', self.state) + await self.ws.close(code=4000, reason=repr(err)) finally: if self.state: self.state.ws = None diff --git a/litecord/schemas.py b/litecord/schemas.py index c66a5df..5900d80 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -8,6 +8,15 @@ USERNAME_REGEX = re.compile(r'^[a-zA-Z0-9_]{2,19}$', re.A) EMAIL_REGEX = re.compile(r'^[a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+$', re.A) + +# collection of regexes +USER_MENTION = re.compile(r'<@!?(\d+)>', re.A | re.M) +CHAN_MENTION = re.compile(r'<#(\d+)>', re.A | re.M) +ROLE_MENTION = re.compile(r'<@&(\d+)>', re.A | re.M) +EMOJO_MENTION = re.compile(r'<:(\.+):(\d+)>', re.A | re.M) +ANIMOJI_MENTION = re.compile(r'', re.A | re.M) + + class LitecordValidator(Validator): def _validate_type_username(self, value: str) -> bool: """Validate against the username regex.""" @@ -61,3 +70,11 @@ MEMBER_UPDATE = { 'deaf': {'type': 'bool', 'required': False}, 'channel_id': {'type': 'snowflake', 'required': False}, } + +MESSAGE_CREATE = { + 'content': {'type': 'string', 'minlength': 1, 'maxlength': 2000}, + 'nonce': {'type': 'number', 'required': False}, + 'tts': {'type': 'boolean', 'required': False}, + + # TODO: file, embed, payload_json +} diff --git a/litecord/storage.py b/litecord/storage.py index 42cf93e..facf669 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -1,6 +1,11 @@ from typing import List, Dict, Any from .enums import ChannelType +from .schemas import USER_MENTION, ROLE_MENTION + + +async def _dummy(any_id): + return str(any_id) class Storage: @@ -26,8 +31,8 @@ class Storage: if not secure: duser.pop('email') - duser.pop('mfa_enabled') duser.pop('verified') + duser.pop('mfa_enabled') return duser @@ -174,6 +179,51 @@ class Storage: return {**row, **dict(vrow)} + async def get_chan_type(self, channel_id) -> int: + return await self.db.fetchval(""" + SELECT channel_type + FROM channels + WHERE channels.id = $1 + """, channel_id) + + async def _chan_overwrites(self, channel_id): + overwrite_rows = await self.db.fetch(""" + SELECT target_id::text AS id, overwrite_type, allow, deny + FROM channel_overwrites + WHERE channel_id = $1 + """, channel_id) + + def _overwrite_convert(ov_row): + drow = dict(ov_row) + drow['type'] = drow['overwrite_type'] + drow.pop('overwrite_type') + return drow + + return map(_overwrite_convert, overwrite_rows) + + async def get_channel(self, channel_id) -> Dict[str, Any]: + """Fetch a single channel's information.""" + chan_type = await self.get_chan_type(channel_id) + + if chan_type in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE, + ChannelType.GUILD_CATEGORY): + base = await self.db.fetchrow(""" + SELECT id, guild_id::text, parent_id, name, position, nsfw + FROM guild_channels + WHERE guild_channels.id = $1 + """, channel_id) + + res = await self._channels_extra(dict(base), chan_type) + res['type'] = chan_type + res['permission_overwrites'] = \ + list(await self._chan_overwrites(channel_id)) + + res['id'] = str(res['id']) + return res + else: + # TODO: dms and group dms + pass + async def get_channel_data(self, guild_id) -> List[Dict]: """Get channel information on a guild""" channel_basics = await self.db.fetch(""" @@ -193,22 +243,8 @@ class Storage: res = await self._channels_extra(dict(row), ctype) res['type'] = ctype - # type is a SQL keyword, so we can't do - # 'overwrite_type AS type' - overwrite_rows = await self.db.fetch(""" - SELECT target_id::text AS id, overwrite_type, allow, deny - FROM channel_overwrites - WHERE channel_id = $1 - """, row['id']) - - def _overwrite_convert(ov_row): - drow = dict(ov_row) - drow['type'] = drow['overwrite_type'] - drow.pop('overwrite_type') - return drow - - res['permission_overwrites'] = list(map(_overwrite_convert, - overwrite_rows)) + res['permission_overwrites'] = \ + list(await self._chan_overwrites(row['id'])) # Making sure. res['id'] = str(res['id']) @@ -262,6 +298,69 @@ class Storage: 'voice_states': [], 'channels': channels, 'roles': roles, - # TODO: finish those + + # TODO: finish presences 'presences': [], }} + + async def _msg_regex(self, regex, method, content) -> List[Dict]: + res = [] + + for match in regex.finditer(content): + found_id = match.group(1) + + try: + found_id = int(found_id) + except ValueError: + continue + + obj = await method(found_id) + if obj: + res.append(obj) + + return res + + async def get_message(self, message_id: int) -> Dict: + """Get a single message's payload.""" + row = await self.db.fetchrow(""" + SELECT id::text, channel_id::text, author_id, content, + created_at AS timestamp, edited_at AS edited_timestamp, + tts, mention_everyone, nonce, message_type + FROM messages + WHERE id = $1 + """, message_id) + + if not row: + return + + res = dict(row) + res['timestamp'] = res['timestamp'].isoformat() + res['type'] = res['message_type'] + res.pop('message_type') + + # calculate user mentions and role mentions by regex + res['mentions'] = await self._msg_regex(USER_MENTION, self.get_user, + row['content']) + + # _dummy just returns the string of the id, since we don't + # actually use the role objects in mention_roles, just their ids. + res['mention_roles'] = await self._msg_regex(ROLE_MENTION, _dummy, + row['content']) + + # TODO: handle webhook authors + res['author'] = await self.get_user(res['author_id']) + res.pop('author_id') + + # TODO: res['attachments'] + res['attachments'] = [] + + # TODO: res['embeds'] + res['embeds'] = [] + + # TODO: res['reactions'] + res['reactions'] = [] + + # TODO: res['pinned'] + res['pinned'] = False + + return res diff --git a/run.py b/run.py index b00c022..b3794a6 100644 --- a/run.py +++ b/run.py @@ -63,8 +63,8 @@ async def app_before_serving(): async def _wrapper(ws, url): # We wrap the main websocket_handler # so we can pass quart's app object. - await websocket_handler((app.db, app.state_manager, - app.storage, app.loop), ws, url) + await websocket_handler((app.db, app.state_manager, app.storage, + app.loop, app.dispatcher), ws, url) ws_future = websockets.serve(_wrapper, host, port)