From 61e36f244bc6cf6e723ae37288e544cc8fca3caf Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Wed, 3 Oct 2018 21:43:16 -0300 Subject: [PATCH] blueprints.users, channels: basic dm operations SQL for instances: ```sql ALTER TABLE messages ADD CONSTRAINT messages_channels_fkey FOREIGN KEY (channel_id) REFERENCES channels (id) ON DELETE CASCADE; ALTER TABLE channel_pins ADD CONSTRAINT pins_channels_fkey FOREIGN KEY (channel_id) REFERENCES channels (id) ON DELETE CASCADE; ALTER TABLE channel_pins ADD CONSTRAINT pins_messages_fkey FOREIGN KEY (message_id) REFERENCES messages (id) ON DELETE CASCADE; ``` After that, rerun `schema.sql`. blueprints.channels: - check dms on channel_check - add DELETE /api/v6/channels/ blueprints.users: - add event dispatching for leaving guilds - add GET /api/v6/users/@me/channels, for DM fetching - add POST /api/v6/users/@me/channels, for DM creation - add POST /api/v6/users//channels for DM / Group DM creation - schemas: add CREATE, CREATE_GROUP_DM - storage: add last_message_id fetching for channels - storage: add support for DMs in get_channel - storage: add Storage.get_dm, Storage.get_dms, Storage.get_all_dms - schema.sql: add dm_channel_state table - schema.sql: add constriants for messages.channel_id and channel_pins --- litecord/blueprints/channels.py | 154 ++++++++++++++++++++++++++++++-- litecord/blueprints/users.py | 94 +++++++++++++++++-- litecord/enums.py | 5 ++ litecord/schemas.py | 13 +++ litecord/storage.py | 109 ++++++++++++++++++++-- schema.sql | 13 ++- 6 files changed, 364 insertions(+), 24 deletions(-) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index cebe3fb..80b6a6f 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -5,7 +5,7 @@ from logbook import Logger from ..auth import token_check from ..snowflake import get_snowflake, snowflake_datetime -from ..enums import ChannelType, MessageType +from ..enums import ChannelType, MessageType, GUILD_CHANS from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound from ..schemas import validate, MESSAGE_CREATE @@ -18,13 +18,14 @@ bp = Blueprint('channels', __name__) async def channel_check(user_id, channel_id): """Check if the current user is authorized to read the channel's information.""" - ctype = await app.storage.get_chan_type(channel_id) + chan_type = await app.storage.get_chan_type(channel_id) - if ctype is None: + if chan_type is None: raise ChannelNotFound(f'channel type not found') - if ChannelType(ctype) in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE, - ChannelType.GUILD_CATEGORY): + ctype = ChannelType(chan_type) + + if ctype in GUILD_CHANS: guild_id = await app.db.fetchval(""" SELECT guild_id FROM guild_channels @@ -34,10 +35,25 @@ async def channel_check(user_id, channel_id): await guild_check(user_id, guild_id) return guild_id + if ctype == ChannelType.DM: + parties = await app.db.fetchval(""" + SELECT party1_id, party2_id + FROM dm_channels + WHERE id = $1 AND (party1_id = $2 OR party2_id = $2) + """, channel_id, user_id) + + # get the id of the other party + parties.remove(user_id) + return parties[0] + @bp.route('/', methods=['GET']) async def get_channel(channel_id): + """Get a single channel's information""" user_id = await token_check() + + # channel_check takes care of checking + # DMs and group DMs await channel_check(user_id, channel_id) chan = await app.storage.get_channel(channel_id) @@ -47,6 +63,129 @@ async def get_channel(channel_id): return jsonify(chan) +async def __guild_chan_sql(guild_id, channel_id, field: str) -> str: + """Update a guild's channel id field to NULL, + if it was set to the given channel id before.""" + return await app.db.execute(f""" + UPDATE guilds + SET {field} = NULL + WHERE guilds.id = $1 AND {field} = $2 + """, guild_id, channel_id) + + +async def _update_guild_chan_text(guild_id: int, channel_id: int): + res_embed = await __guild_chan_sql( + guild_id, channel_id, 'embed_channel_id') + + res_widget = await __guild_chan_sql( + guild_id, channel_id, 'widget_channel_id') + + res_system = await __guild_chan_sql( + guild_id, channel_id, 'system_channel_id') + + # if none of them were actually updated, + # ignore and dont dispatch anything + if 'UPDATE 1' not in (res_embed, res_widget, res_system): + return + + # at least one of the fields were updated, + # dispatch GUILD_UPDATE + guild = await app.storage.get_guild(guild_id) + await app.dispatcher.dispatch_guild( + guild_id, 'GUILD_UPDATE', guild) + + +async def _update_guild_chan_voice(guild_id: int, channel_id: int): + res = await __guild_chan_sql(guild_id, channel_id, 'afk_channel_id') + + # guild didnt update + if res == 'UPDATE 0': + return + + guild = await app.storage.get_guild(guild_id) + await app.dispatcher.dispatch_guild( + guild_id, 'GUILD_UPDATE', guild) + + +async def _update_guild_chan_cat(guild_id: int, channel_id: int): + # get all channels that were childs of the category + childs = await app.db.fetch(""" + SELECT id + FROM guild_channels + WHERE guild_id = $1 AND parent_id = $2 + """, guild_id, channel_id) + childs = [c['id'] for c in childs] + + # update every child channel to parent_id = NULL + await app.db.execute(""" + UPDATE guild_channels + SET parent_id = NULL + WHERE guild_id = $1 AND parent_id = $2 + """, guild_id, channel_id) + + # tell all people in the guild of the category removal + for child_id in childs: + child = await app.storage.get_channel(child_id) + await app.dispatcher.dispatch_guild( + guild_id, 'CHANNEL_UPDATE', child + ) + + +@bp.route('/', methods=['DELETE']) +async def close_channel(channel_id): + user_id = await token_check() + + chan_type = await app.storage.get_chan_type(channel_id) + ctype = ChannelType(chan_type) + + if ctype in GUILD_CHANS: + guild_id = await channel_check(user_id, channel_id) + chan = await app.storage.get_channel(channel_id) + + # the selected function will take care of checking + # the sanity of tables once the channel becomes deleted. + _update_func = { + ChannelType.GUILD_TEXT: _update_guild_chan_text, + ChannelType.GUILD_VOICE: _update_guild_chan_voice, + ChannelType.GUILD_CATEGORY: _update_guild_chan_cat, + }[ctype] + + await _update_func(guild_id, channel_id) + + # this should take care of deleting all messages as well + # (if any) + await app.db.execute(""" + DELETE FROM guild_channels + WHERE id = $1 + """, channel_id) + + await app.dispatcher.dispatch_guild( + guild_id, 'CHANNEL_DELETE', chan) + return jsonify(chan) + + if ctype == ChannelType.DM: + chan = await app.storage.get_channel(channel_id) + + # we don't ever actually delete DM channels off the database. + # instead, we close the channel for the user that is making + # the request via removing the link between them and + # the channel on dm_channel_state + await app.db.execute(""" + DELETE FROM dm_channel_state (user_id, dm_id) + VALUES ($1, $2) + """, user_id, channel_id) + + # nothing happens to the other party of the dm channel + await app.dispacher.dispatch_user(user_id, 'CHANNEL_DELETE', chan) + return jsonify(chan) + + if ctype == ChannelType.GROUP_DM: + # TODO: group dm + pass + + return '', 404 + + @bp.route('//messages', methods=['GET']) async def get_messages(channel_id): user_id = await token_check() @@ -82,7 +221,6 @@ async def get_single_message(channel_id, message_id): await channel_check(user_id, channel_id) # TODO: check READ_MESSAGE_HISTORY permissions - message = await app.storage.get_message(message_id) if not message: @@ -120,6 +258,8 @@ async def create_message(channel_id): ) # TODO: dispatch_channel + # we really need dispatch_channel to make dm messages work, + # since they aren't part of any existing guild. payload = await app.storage.get_message(message_id) await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload) @@ -266,6 +406,7 @@ async def delete_pin(channel_id, message_id): timestamp = snowflake_datetime(row['message_id']) + # TODO: dispatch_channel await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', { 'channel_id': str(channel_id), 'last_pin_timestamp': timestamp.isoformat() @@ -279,6 +420,7 @@ async def trigger_typing(channel_id): user_id = await token_check() guild_id = await channel_check(user_id, channel_id) + # TODO: dispatch_channel await app.dispatcher.dispatch_guild(guild_id, 'TYPING_START', { 'channel_id': str(channel_id), 'user_id': str(user_id), diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 507b33b..02e83a4 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -2,8 +2,11 @@ from quart import Blueprint, jsonify, request, current_app as app from asyncpg import UniqueViolationError from ..auth import token_check +from ..snowflake import get_snowflake from ..errors import Forbidden, BadRequest -from ..schemas import validate, USER_SETTINGS +from ..schemas import validate, USER_SETTINGS, CREATE_DM, CREATE_GROUP_DM + +from .guilds import guild_check bp = Blueprint('user', __name__) @@ -84,15 +87,31 @@ async def get_me_guilds(): @bp.route('/@me/guilds/', methods=['DELETE']) -async def leave_guild(guild_id): +async def leave_guild(guild_id: int): user_id = await token_check() + await guild_check(user_id, guild_id) await app.db.execute(""" DELETE FROM members WHERE user_id = $1 AND guild_id = $2 """, user_id, guild_id) - # TODO: something to dispatch events to the users + # first dispatch guild delete to the user, + # then remove from the guild, + # then tell the others that the member was removed + await app.dispatcher.dispatch_user_guild( + user_id, guild_id, 'GUILD_DELETE', { + 'id': str(guild_id), + 'unavailable': False, + } + ) + + await app.dispatcher.unsub_guild(guild_id, user_id) + + await app.dispatcher.dispatch_guild('GUILD_MEMBER_REMOVE', { + 'guild_id': str(guild_id), + 'user': await app.storage.get_user(user_id) + }) return '', 204 @@ -102,14 +121,75 @@ async def get_connections(): pass -# @bp.route('/@me/channels', methods=['GET']) +@bp.route('/@me/channels', methods=['GET']) async def get_dms(): - pass + user_id = await token_check() + dms = await app.storage.get_dms(user_id) + return jsonify(dms) -# @bp.route('/@me/channels', methods=['POST']) +async def try_dm_state(user_id, dm_id): + """Try insertin the user into the dm state + for the given DM.""" + try: + await app.db.execute(""" + INSERT INTO dm_channel_state (id, dm_id) + VALUES ($1, $2) + """, user_id, dm_id) + except UniqueViolationError: + # if already in state, ignore + pass + + +async def create_dm(user_id, recipient_id): + dm_id = get_snowflake() + + try: + await app.db.execute(""" + INSERT INTO dm_channels (id, party1_id, party2_id) + VALUES ($1, $2, $3) + """, dm_id, user_id, recipient_id) + + await try_dm_state(user_id, dm_id) + + except UniqueViolationError: + # the dm already exists + dm_id = await app.db.fetchval(""" + SELECT id + FROM dm_channels + WHERE (party1_id = $1 OR party2_id = $1) AND + (party2_id = $2 OR party2_id = $2) + """, user_id, recipient_id) + + dm = await app.storage.get_dm(dm_id, user_id) + return jsonify(dm) + + +@bp.route('/@me/channels', methods=['POST']) async def start_dm(): - pass + """Create a DM with a user.""" + user_id = await token_check() + j = validate(await request.get_json(), CREATE_DM) + recipient_id = j['recipient_id'] + + return await create_dm(user_id, recipient_id) + + +@bp.route('//channels', methods=['POST']) +async def create_group_dm(p_user_id: int): + """Create a DM or a Group DM with user(s).""" + user_id = await token_check() + assert user_id == p_user_id + + j = validate(await request.get_json(), CREATE_GROUP_DM) + recipients = j['recipients'] + + if list(recipients) == 1: + # its a group dm with 1 user... a dm! + return await create_dm(user_id, int(recipients[0])) + + # TODO: group dms + return '', 500 @bp.route('/@me/notes/', methods=['PUT']) diff --git a/litecord/enums.py b/litecord/enums.py index 2b26467..99cc0e2 100644 --- a/litecord/enums.py +++ b/litecord/enums.py @@ -17,6 +17,11 @@ class ChannelType(EasyEnum): GUILD_CATEGORY = 4 +GUILD_CHANS = (ChannelType.GUILD_TEXT, + ChannelType.GUILD_VOICE, + ChannelType.GUILD_CATEGORY) + + class ActivityType(EasyEnum): PLAYING = 0 STREAMING = 1 diff --git a/litecord/schemas.py b/litecord/schemas.py index cd89dca..d265e47 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -279,3 +279,16 @@ RELATIONSHIP = { 'default': RelationshipType.FRIEND.value } } + +CREATE_DM = { + 'recipient_id': { + 'type': 'snowflake', + 'required': True + } +} + +CREATE_GROUP_DM = { + 'type': 'list', + 'required': True, + 'schema': {'type': 'snowflake'} +} diff --git a/litecord/storage.py b/litecord/storage.py index 5557cf7..22b4c39 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -36,6 +36,16 @@ async def _set_json(con): ) +def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int): + """Filter recipients in a list of recipients, removing + the one that is reundant (ourselves).""" + user_id = str(user_id) + + return filter( + lambda recipient: recipient['id'] != user_id, + recipients) + + class Storage: """Class for common SQL statements.""" def __init__(self, db): @@ -215,12 +225,19 @@ class Storage: members = await self.get_member_multi(guild_id, mids) return members + async def _chan_last_message(self, channel_id: int): + return await self.db.fetch(""" + SELECT MAX(id) + FROM messages + WHERE channel_id = $1 + """, channel_id) + async def _channels_extra(self, row) -> Dict: """Fill in more information about a channel.""" channel_type = row['type'] - # TODO: dm and group dm? chan_type = ChannelType(channel_type) + if chan_type == ChannelType.GUILD_TEXT: topic = await self.db.fetchval(""" SELECT topic FROM guild_text_channels @@ -229,6 +246,8 @@ class Storage: return {**row, **{ 'topic': topic, + 'last_message_id': str( + await self._chan_last_message(row['id'])) }} elif chan_type == ChannelType.GUILD_VOICE: vrow = await self.db.fetchval(""" @@ -240,7 +259,8 @@ class Storage: log.warning('unknown channel type: {}', chan_type) - async def get_chan_type(self, channel_id) -> int: + async def get_chan_type(self, channel_id: int) -> int: + """Get the channel type integer, given channel ID.""" return await self.db.fetchval(""" SELECT channel_type FROM channels @@ -275,13 +295,14 @@ class Storage: return list(map(_overwrite_convert, overwrite_rows)) - async def get_channel(self, channel_id) -> Dict[str, Any]: + async def get_channel(self, channel_id: int) -> Dict[str, Any]: """Fetch a single channel's information.""" chan_type = await self.get_chan_type(channel_id) + ctype = ChannelType(chan_type) - if ChannelType(chan_type) in (ChannelType.GUILD_TEXT, - ChannelType.GUILD_VOICE, - ChannelType.GUILD_CATEGORY): + if ctype in (ChannelType.GUILD_TEXT, + ChannelType.GUILD_VOICE, + ChannelType.GUILD_CATEGORY): base = await self.db.fetchrow(""" SELECT id, guild_id::text, parent_id, name, position, nsfw FROM guild_channels @@ -297,10 +318,35 @@ class Storage: res['id'] = str(res['id']) return res - else: - # TODO: dms and group dms + elif ctype == ChannelType.DM: + dm_row = await self.db.fetchrow(""" + SELECT party1_id, party2_id + FROM dm_channels + WHERE id = $1 + """, channel_id) + + drow = dict(dm_row) + drow['type'] = chan_type + + drow['last_message_id'] = str( + await self._chan_last_message(channel_id)) + + # dms have just two recipients. + drow['recipients'] = [ + await self.get_user(drow['party1_id']), + await self.get_user(drow['party2_id']) + ] + + drow.pop('party1_id') + drow.pop('party2_id') + + drow['id'] = str(drow['id']) + return drow + elif ctype == ChannelType.GROUP_DM: pass + return None + async def get_channel_data(self, guild_id) -> List[Dict]: """Get channel information on a guild""" channel_basics = await self.db.fetch(""" @@ -574,6 +620,7 @@ class Storage: return dinv async def get_user_settings(self, user_id: int) -> Dict[str, Any]: + """Get current user settings.""" row = await self._fetchrow_with_json(""" SELECT * FROM user_settings @@ -688,3 +735,49 @@ class Storage: res.append(drow) return res + + async def get_dm(self, dm_id: int, user_id: int = None): + dm_chan = await self.get_channel(dm_id) + + if user_id: + dm_chan['recipients'] = _filter_recipients( + dm_chan['recipients'], user_id + ) + + return dm_chan + + async def get_dms(self, user_id: int) -> List[Dict[str, Any]]: + """Get all DM channels for a user, including group DMs. + + This will only fetch channels the user has in their state, + which is different than the whole list of DM channels. + """ + dm_ids = await self.db.fetch(""" + SELECT id + FROM dm_channel_state + WHERE user_id = $1 + """, user_id) + + res = [] + + for dm_id in dm_ids: + dm_chan = await self.get_dm(dm_id, user_id) + res.append(dm_chan) + + return res + + async def get_all_dms(self, user_id: int) -> List[Dict[str, Any]]: + """Get all DMs for a user, regardless of the DM state.""" + dm_ids = await self.db.fetch(""" + SELECT id + FROM dm_channels + WHERE party1_id = $1 OR party2_id = $2 + """, user_id) + + res = [] + + for dm_id in dm_ids: + dm_chan = await self.get_dm(dm_id, user_id) + res.append(dm_chan) + + return res diff --git a/schema.sql b/schema.sql index 54b4ea0..c1a0158 100644 --- a/schema.sql +++ b/schema.sql @@ -261,6 +261,13 @@ CREATE TABLE IF NOT EXISTS dm_channels ( ); +CREATE TABLE IF NOT EXISTS dm_channel_state ( + user_id bigint REFERENCES users (id) ON DELETE CASCADE, + dm_id bigint REFERENCES dm_channels (id) ON DELETE CASCADE, + PRIMARY KEY (user_id, dm_id) +); + + CREATE TABLE IF NOT EXISTS group_dm_channels ( id bigint REFERENCES channels (id) ON DELETE CASCADE, owner_id bigint REFERENCES users (id), @@ -440,7 +447,7 @@ CREATE TABLE IF NOT EXISTS embeds ( CREATE TABLE IF NOT EXISTS messages ( id bigint PRIMARY KEY, - channel_id bigint REFERENCES channels (id), + channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, -- those are mutually exclusive, only one of them -- can NOT be NULL at a time. @@ -486,7 +493,7 @@ CREATE TABLE IF NOT EXISTS message_reactions ( ); CREATE TABLE IF NOT EXISTS channel_pins ( - channel_id bigint REFERENCES channels (id), - message_id bigint REFERENCES messages (id), + channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, + message_id bigint REFERENCES messages (id) ON DELETE CASCADE, PRIMARY KEY (channel_id, message_id) );