From 20332805b8314211fadc608afd43436e8af4f3d0 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Thu, 11 Oct 2018 23:01:56 -0300 Subject: [PATCH] blueprints.channels: support dms on message create - storage: add Storage.get_dm_peer - gateway.websocket: subscribe to dms as well as guilds --- litecord/blueprints/channels.py | 63 ++++++++++++++++++++++++++------- litecord/blueprints/checks.py | 13 ++----- litecord/blueprints/users.py | 3 ++ litecord/gateway/websocket.py | 28 +++++++++------ litecord/storage.py | 15 ++++++++ 5 files changed, 89 insertions(+), 33 deletions(-) diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 5bc0b84..3faa800 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -253,16 +253,49 @@ async def get_single_message(channel_id, message_id): return jsonify(message) +async def _dm_pre_dispatch(channel_id, peer_id): + """Doo some checks pre-MESSAGE_CREATE so we + make sure the receiving party will handle everything.""" + + # check the other party's dm_channel_state + + dm_state = await app.db.fetchval(""" + SELECT dm_id + FROM dm_channel_state + WHERE user_id = $1 AND dm_id = $2 + """, peer_id, channel_id) + + if dm_state: + # the peer already has the channel + # opened, so we don't need to do anything + return + + dm_chan = await app.storage.get_channel(channel_id) + + # dispatch CHANNEL_CREATE so the client knows which + # channel the future event is about + await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan) + + # subscribe the peer to the channel + await app.dispatcher.sub('channel', channel_id, peer_id) + + # insert it on dm_channel_state so the client + # is subscribed on the future + await app.db.execute(""" + INSERT INTO dm_channel_state(user_id, dm_id) + VALUES ($1, $2) + """, peer_id, channel_id) + + @bp.route('//messages', methods=['POST']) async def create_message(channel_id): user_id = await token_check() - _ctype, guild_id = await channel_check(user_id, channel_id) + ctype, guild_id = await channel_check(user_id, channel_id) j = validate(await request.get_json(), MESSAGE_CREATE) message_id = get_snowflake() # TODO: check SEND_MESSAGES permission - # TODO: check SEND_TTS_MESSAGES # TODO: check connection to the gateway await app.db.execute( @@ -275,30 +308,36 @@ async def create_message(channel_id): channel_id, user_id, j['content'], + + # TODO: check SEND_TTS_MESSAGES j.get('tts', False), + + # TODO: check MENTION_EVERYONE permissions '@everyone' in j['content'], int(j.get('nonce', 0)), MessageType.DEFAULT.value ) - # TODO: dispatch_channel - # we really need dispatch_channel to make dm messages work, - # since they aren't part of any existing guild. payload = await app.storage.get_message(message_id) + + if ctype == ChannelType.DM: + # guild id here is the peer's ID. + await _dm_pre_dispatch(channel_id, guild_id) await app.dispatcher.dispatch('channel', channel_id, 'MESSAGE_CREATE', payload) # TODO: dispatch the MESSAGE_CREATE to any mentioning user. - for str_uid in payload['mentions']: - uid = int(str_uid) + if ctype == ChannelType.GUILD_TEXT: + for str_uid in payload['mentions']: + uid = int(str_uid) - await app.db.execute(""" - UPDATE user_read_state - SET mention_count += 1 - WHERE user_id = $1 AND channel_id = $2 - """, uid, channel_id) + await app.db.execute(""" + UPDATE user_read_state + SET mention_count += 1 + WHERE user_id = $1 AND channel_id = $2 + """, uid, channel_id) return jsonify(payload) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 2639c72..5cfc225 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -37,14 +37,5 @@ async def channel_check(user_id, channel_id): return ctype, guild_id if ctype == ChannelType.DM: - parties = await app.db.fetchrow(""" - SELECT party1_id, party2_id - FROM dm_channels - WHERE id = $1 AND (party1_id = $2 OR party2_id = $2) - """, channel_id, user_id) - - parties = [parties['party1_id'], parties['party2_id']] - - # get the id of the other party - parties.remove(user_id) - return ctype, parties[0] + peer_id = await app.storage.get_dm_peer(channel_id, user_id) + return ctype, peer_id diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index e244365..a5670cc 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -289,8 +289,11 @@ async def get_library(): @bp.route('//profile', methods=['GET']) async def get_profile(peer_id: int): + """Get a user's profile.""" user_id = await token_check() + # TODO: check if they have any mutual guilds, + # and return empty profile if they don't. peer = await app.storage.get_user(peer_id) if not peer: diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 2b53cc1..fdcdd27 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -278,20 +278,28 @@ class GatewayWebsocket: async def _guild_ids(self): # TODO: account for sharding - guild_ids = await self.ext.db.fetch(""" - SELECT guild_id - FROM members - WHERE user_id = $1 - """, self.state.user_id) - - return [r['guild_id'] for r in guild_ids] + return await self.storage.get_user_guilds( + self.state.user_id + ) async def subscribe_guilds(self): - """Subscribe to all available guilds""" + """Subscribe to all available guilds and DM channels. + + Subscribing to channels is already handled + by GuildDispatcher.sub + """ + user_id = self.state.user_id + guild_ids = await self._guild_ids() log.info('subscribing to {} guilds', len(guild_ids)) - await self.ext.dispatcher.sub_many('guild', - self.state.user_id, guild_ids) + await self.ext.dispatcher.sub_many('guild', user_id, guild_ids) + + # subscribe the user to all dms they have OPENED. + dms = await self.storage.get_dms(user_id) + dm_ids = [int(dm['id']) for dm in dms] + + log.info('subscribing to {} dms', len(dm_ids)) + await self.ext.dispatcher.sub_many('channel', user_id, dm_ids) async def update_status(self, status: dict): """Update the status of the current websocket connection.""" diff --git a/litecord/storage.py b/litecord/storage.py index a1dc2d0..d52ff9c 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -847,3 +847,18 @@ class Storage: FROM guild_channels WHERE id = $1 """, channel_id) + + async def get_dm_peer(self, channel_id: int, user_id: int) -> int: + """Get the peer id on a dm""" + parties = await self.db.fetchrow(""" + SELECT party1_id, party2_id + FROM dm_channels + WHERE id = $1 AND (party1_id = $2 OR party2_id = $2) + """, channel_id, user_id) + + parties = [parties['party1_id'], parties['party2_id']] + + # get the id of the other party + parties.remove(user_id) + + return parties[0]