blueprints.channels: support dms on message create

- storage: add Storage.get_dm_peer
 - gateway.websocket: subscribe to dms as well as guilds
This commit is contained in:
Luna Mendes 2018-10-11 23:01:56 -03:00
parent efefb0cc2f
commit 20332805b8
5 changed files with 89 additions and 33 deletions

View File

@ -253,16 +253,49 @@ async def get_single_message(channel_id, message_id):
return jsonify(message) return jsonify(message)
async def _dm_pre_dispatch(channel_id, peer_id):
"""Doo some checks pre-MESSAGE_CREATE so we
make sure the receiving party will handle everything."""
# check the other party's dm_channel_state
dm_state = await app.db.fetchval("""
SELECT dm_id
FROM dm_channel_state
WHERE user_id = $1 AND dm_id = $2
""", peer_id, channel_id)
if dm_state:
# the peer already has the channel
# opened, so we don't need to do anything
return
dm_chan = await app.storage.get_channel(channel_id)
# dispatch CHANNEL_CREATE so the client knows which
# channel the future event is about
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan)
# subscribe the peer to the channel
await app.dispatcher.sub('channel', channel_id, peer_id)
# insert it on dm_channel_state so the client
# is subscribed on the future
await app.db.execute("""
INSERT INTO dm_channel_state(user_id, dm_id)
VALUES ($1, $2)
""", peer_id, channel_id)
@bp.route('/<int:channel_id>/messages', methods=['POST']) @bp.route('/<int:channel_id>/messages', methods=['POST'])
async def create_message(channel_id): async def create_message(channel_id):
user_id = await token_check() user_id = await token_check()
_ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
j = validate(await request.get_json(), MESSAGE_CREATE) j = validate(await request.get_json(), MESSAGE_CREATE)
message_id = get_snowflake() message_id = get_snowflake()
# TODO: check SEND_MESSAGES permission # TODO: check SEND_MESSAGES permission
# TODO: check SEND_TTS_MESSAGES
# TODO: check connection to the gateway # TODO: check connection to the gateway
await app.db.execute( await app.db.execute(
@ -275,22 +308,28 @@ async def create_message(channel_id):
channel_id, channel_id,
user_id, user_id,
j['content'], j['content'],
# TODO: check SEND_TTS_MESSAGES
j.get('tts', False), j.get('tts', False),
# TODO: check MENTION_EVERYONE permissions
'@everyone' in j['content'], '@everyone' in j['content'],
int(j.get('nonce', 0)), int(j.get('nonce', 0)),
MessageType.DEFAULT.value MessageType.DEFAULT.value
) )
# TODO: dispatch_channel
# we really need dispatch_channel to make dm messages work,
# since they aren't part of any existing guild.
payload = await app.storage.get_message(message_id) payload = await app.storage.get_message(message_id)
if ctype == ChannelType.DM:
# guild id here is the peer's ID.
await _dm_pre_dispatch(channel_id, guild_id)
await app.dispatcher.dispatch('channel', channel_id, await app.dispatcher.dispatch('channel', channel_id,
'MESSAGE_CREATE', payload) 'MESSAGE_CREATE', payload)
# TODO: dispatch the MESSAGE_CREATE to any mentioning user. # TODO: dispatch the MESSAGE_CREATE to any mentioning user.
if ctype == ChannelType.GUILD_TEXT:
for str_uid in payload['mentions']: for str_uid in payload['mentions']:
uid = int(str_uid) uid = int(str_uid)

View File

@ -37,14 +37,5 @@ async def channel_check(user_id, channel_id):
return ctype, guild_id return ctype, guild_id
if ctype == ChannelType.DM: if ctype == ChannelType.DM:
parties = await app.db.fetchrow(""" peer_id = await app.storage.get_dm_peer(channel_id, user_id)
SELECT party1_id, party2_id return ctype, peer_id
FROM dm_channels
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
""", channel_id, user_id)
parties = [parties['party1_id'], parties['party2_id']]
# get the id of the other party
parties.remove(user_id)
return ctype, parties[0]

View File

@ -289,8 +289,11 @@ async def get_library():
@bp.route('/<int:peer_id>/profile', methods=['GET']) @bp.route('/<int:peer_id>/profile', methods=['GET'])
async def get_profile(peer_id: int): async def get_profile(peer_id: int):
"""Get a user's profile."""
user_id = await token_check() user_id = await token_check()
# TODO: check if they have any mutual guilds,
# and return empty profile if they don't.
peer = await app.storage.get_user(peer_id) peer = await app.storage.get_user(peer_id)
if not peer: if not peer:

View File

@ -278,20 +278,28 @@ class GatewayWebsocket:
async def _guild_ids(self): async def _guild_ids(self):
# TODO: account for sharding # TODO: account for sharding
guild_ids = await self.ext.db.fetch(""" return await self.storage.get_user_guilds(
SELECT guild_id self.state.user_id
FROM members )
WHERE user_id = $1
""", self.state.user_id)
return [r['guild_id'] for r in guild_ids]
async def subscribe_guilds(self): async def subscribe_guilds(self):
"""Subscribe to all available guilds""" """Subscribe to all available guilds and DM channels.
Subscribing to channels is already handled
by GuildDispatcher.sub
"""
user_id = self.state.user_id
guild_ids = await self._guild_ids() guild_ids = await self._guild_ids()
log.info('subscribing to {} guilds', len(guild_ids)) log.info('subscribing to {} guilds', len(guild_ids))
await self.ext.dispatcher.sub_many('guild', await self.ext.dispatcher.sub_many('guild', user_id, guild_ids)
self.state.user_id, guild_ids)
# subscribe the user to all dms they have OPENED.
dms = await self.storage.get_dms(user_id)
dm_ids = [int(dm['id']) for dm in dms]
log.info('subscribing to {} dms', len(dm_ids))
await self.ext.dispatcher.sub_many('channel', user_id, dm_ids)
async def update_status(self, status: dict): async def update_status(self, status: dict):
"""Update the status of the current websocket connection.""" """Update the status of the current websocket connection."""

View File

@ -847,3 +847,18 @@ class Storage:
FROM guild_channels FROM guild_channels
WHERE id = $1 WHERE id = $1
""", channel_id) """, channel_id)
async def get_dm_peer(self, channel_id: int, user_id: int) -> int:
"""Get the peer id on a dm"""
parties = await self.db.fetchrow("""
SELECT party1_id, party2_id
FROM dm_channels
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
""", channel_id, user_id)
parties = [parties['party1_id'], parties['party2_id']]
# get the id of the other party
parties.remove(user_id)
return parties[0]