mirror of https://gitlab.com/litecord/litecord.git
blueprints.channels: support dms on message create
- storage: add Storage.get_dm_peer - gateway.websocket: subscribe to dms as well as guilds
This commit is contained in:
parent
efefb0cc2f
commit
20332805b8
|
|
@ -253,16 +253,49 @@ async def get_single_message(channel_id, message_id):
|
||||||
return jsonify(message)
|
return jsonify(message)
|
||||||
|
|
||||||
|
|
||||||
|
async def _dm_pre_dispatch(channel_id, peer_id):
|
||||||
|
"""Doo some checks pre-MESSAGE_CREATE so we
|
||||||
|
make sure the receiving party will handle everything."""
|
||||||
|
|
||||||
|
# check the other party's dm_channel_state
|
||||||
|
|
||||||
|
dm_state = await app.db.fetchval("""
|
||||||
|
SELECT dm_id
|
||||||
|
FROM dm_channel_state
|
||||||
|
WHERE user_id = $1 AND dm_id = $2
|
||||||
|
""", peer_id, channel_id)
|
||||||
|
|
||||||
|
if dm_state:
|
||||||
|
# the peer already has the channel
|
||||||
|
# opened, so we don't need to do anything
|
||||||
|
return
|
||||||
|
|
||||||
|
dm_chan = await app.storage.get_channel(channel_id)
|
||||||
|
|
||||||
|
# dispatch CHANNEL_CREATE so the client knows which
|
||||||
|
# channel the future event is about
|
||||||
|
await app.dispatcher.dispatch_user(peer_id, 'CHANNEL_CREATE', dm_chan)
|
||||||
|
|
||||||
|
# subscribe the peer to the channel
|
||||||
|
await app.dispatcher.sub('channel', channel_id, peer_id)
|
||||||
|
|
||||||
|
# insert it on dm_channel_state so the client
|
||||||
|
# is subscribed on the future
|
||||||
|
await app.db.execute("""
|
||||||
|
INSERT INTO dm_channel_state(user_id, dm_id)
|
||||||
|
VALUES ($1, $2)
|
||||||
|
""", peer_id, channel_id)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:channel_id>/messages', methods=['POST'])
|
@bp.route('/<int:channel_id>/messages', methods=['POST'])
|
||||||
async def create_message(channel_id):
|
async def create_message(channel_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
_ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
j = validate(await request.get_json(), MESSAGE_CREATE)
|
j = validate(await request.get_json(), MESSAGE_CREATE)
|
||||||
message_id = get_snowflake()
|
message_id = get_snowflake()
|
||||||
|
|
||||||
# TODO: check SEND_MESSAGES permission
|
# TODO: check SEND_MESSAGES permission
|
||||||
# TODO: check SEND_TTS_MESSAGES
|
|
||||||
# TODO: check connection to the gateway
|
# TODO: check connection to the gateway
|
||||||
|
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
|
|
@ -275,30 +308,36 @@ async def create_message(channel_id):
|
||||||
channel_id,
|
channel_id,
|
||||||
user_id,
|
user_id,
|
||||||
j['content'],
|
j['content'],
|
||||||
|
|
||||||
|
# TODO: check SEND_TTS_MESSAGES
|
||||||
j.get('tts', False),
|
j.get('tts', False),
|
||||||
|
|
||||||
|
# TODO: check MENTION_EVERYONE permissions
|
||||||
'@everyone' in j['content'],
|
'@everyone' in j['content'],
|
||||||
int(j.get('nonce', 0)),
|
int(j.get('nonce', 0)),
|
||||||
MessageType.DEFAULT.value
|
MessageType.DEFAULT.value
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: dispatch_channel
|
|
||||||
# we really need dispatch_channel to make dm messages work,
|
|
||||||
# since they aren't part of any existing guild.
|
|
||||||
payload = await app.storage.get_message(message_id)
|
payload = await app.storage.get_message(message_id)
|
||||||
|
|
||||||
|
if ctype == ChannelType.DM:
|
||||||
|
# guild id here is the peer's ID.
|
||||||
|
await _dm_pre_dispatch(channel_id, guild_id)
|
||||||
|
|
||||||
await app.dispatcher.dispatch('channel', channel_id,
|
await app.dispatcher.dispatch('channel', channel_id,
|
||||||
'MESSAGE_CREATE', payload)
|
'MESSAGE_CREATE', payload)
|
||||||
|
|
||||||
# TODO: dispatch the MESSAGE_CREATE to any mentioning user.
|
# TODO: dispatch the MESSAGE_CREATE to any mentioning user.
|
||||||
|
|
||||||
for str_uid in payload['mentions']:
|
if ctype == ChannelType.GUILD_TEXT:
|
||||||
uid = int(str_uid)
|
for str_uid in payload['mentions']:
|
||||||
|
uid = int(str_uid)
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute("""
|
||||||
UPDATE user_read_state
|
UPDATE user_read_state
|
||||||
SET mention_count += 1
|
SET mention_count += 1
|
||||||
WHERE user_id = $1 AND channel_id = $2
|
WHERE user_id = $1 AND channel_id = $2
|
||||||
""", uid, channel_id)
|
""", uid, channel_id)
|
||||||
|
|
||||||
return jsonify(payload)
|
return jsonify(payload)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,14 +37,5 @@ async def channel_check(user_id, channel_id):
|
||||||
return ctype, guild_id
|
return ctype, guild_id
|
||||||
|
|
||||||
if ctype == ChannelType.DM:
|
if ctype == ChannelType.DM:
|
||||||
parties = await app.db.fetchrow("""
|
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
|
||||||
SELECT party1_id, party2_id
|
return ctype, peer_id
|
||||||
FROM dm_channels
|
|
||||||
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
|
|
||||||
""", channel_id, user_id)
|
|
||||||
|
|
||||||
parties = [parties['party1_id'], parties['party2_id']]
|
|
||||||
|
|
||||||
# get the id of the other party
|
|
||||||
parties.remove(user_id)
|
|
||||||
return ctype, parties[0]
|
|
||||||
|
|
|
||||||
|
|
@ -289,8 +289,11 @@ async def get_library():
|
||||||
|
|
||||||
@bp.route('/<int:peer_id>/profile', methods=['GET'])
|
@bp.route('/<int:peer_id>/profile', methods=['GET'])
|
||||||
async def get_profile(peer_id: int):
|
async def get_profile(peer_id: int):
|
||||||
|
"""Get a user's profile."""
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
|
|
||||||
|
# TODO: check if they have any mutual guilds,
|
||||||
|
# and return empty profile if they don't.
|
||||||
peer = await app.storage.get_user(peer_id)
|
peer = await app.storage.get_user(peer_id)
|
||||||
|
|
||||||
if not peer:
|
if not peer:
|
||||||
|
|
|
||||||
|
|
@ -278,20 +278,28 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
async def _guild_ids(self):
|
async def _guild_ids(self):
|
||||||
# TODO: account for sharding
|
# TODO: account for sharding
|
||||||
guild_ids = await self.ext.db.fetch("""
|
return await self.storage.get_user_guilds(
|
||||||
SELECT guild_id
|
self.state.user_id
|
||||||
FROM members
|
)
|
||||||
WHERE user_id = $1
|
|
||||||
""", self.state.user_id)
|
|
||||||
|
|
||||||
return [r['guild_id'] for r in guild_ids]
|
|
||||||
|
|
||||||
async def subscribe_guilds(self):
|
async def subscribe_guilds(self):
|
||||||
"""Subscribe to all available guilds"""
|
"""Subscribe to all available guilds and DM channels.
|
||||||
|
|
||||||
|
Subscribing to channels is already handled
|
||||||
|
by GuildDispatcher.sub
|
||||||
|
"""
|
||||||
|
user_id = self.state.user_id
|
||||||
|
|
||||||
guild_ids = await self._guild_ids()
|
guild_ids = await self._guild_ids()
|
||||||
log.info('subscribing to {} guilds', len(guild_ids))
|
log.info('subscribing to {} guilds', len(guild_ids))
|
||||||
await self.ext.dispatcher.sub_many('guild',
|
await self.ext.dispatcher.sub_many('guild', user_id, guild_ids)
|
||||||
self.state.user_id, guild_ids)
|
|
||||||
|
# subscribe the user to all dms they have OPENED.
|
||||||
|
dms = await self.storage.get_dms(user_id)
|
||||||
|
dm_ids = [int(dm['id']) for dm in dms]
|
||||||
|
|
||||||
|
log.info('subscribing to {} dms', len(dm_ids))
|
||||||
|
await self.ext.dispatcher.sub_many('channel', user_id, dm_ids)
|
||||||
|
|
||||||
async def update_status(self, status: dict):
|
async def update_status(self, status: dict):
|
||||||
"""Update the status of the current websocket connection."""
|
"""Update the status of the current websocket connection."""
|
||||||
|
|
|
||||||
|
|
@ -847,3 +847,18 @@ class Storage:
|
||||||
FROM guild_channels
|
FROM guild_channels
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
""", channel_id)
|
""", channel_id)
|
||||||
|
|
||||||
|
async def get_dm_peer(self, channel_id: int, user_id: int) -> int:
|
||||||
|
"""Get the peer id on a dm"""
|
||||||
|
parties = await self.db.fetchrow("""
|
||||||
|
SELECT party1_id, party2_id
|
||||||
|
FROM dm_channels
|
||||||
|
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
|
||||||
|
""", channel_id, user_id)
|
||||||
|
|
||||||
|
parties = [parties['party1_id'], parties['party2_id']]
|
||||||
|
|
||||||
|
# get the id of the other party
|
||||||
|
parties.remove(user_id)
|
||||||
|
|
||||||
|
return parties[0]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue