blueprints.users, channels: basic dm operations

SQL for instances:

```sql
ALTER TABLE messages
ADD CONSTRAINT messages_channels_fkey
FOREIGN KEY (channel_id)
REFERENCES channels (id)
ON DELETE CASCADE;

ALTER TABLE channel_pins ADD CONSTRAINT pins_channels_fkey
FOREIGN KEY (channel_id)
REFERENCES channels (id)
ON DELETE CASCADE;

ALTER TABLE channel_pins ADD CONSTRAINT pins_messages_fkey
FOREIGN KEY (message_id)
REFERENCES messages (id)
ON DELETE CASCADE;
```

After that, rerun `schema.sql`.

blueprints.channels:
 - check dms on channel_check
 - add DELETE /api/v6/channels/<int:channel_id>

blueprints.users:
 - add event dispatching for leaving guilds
 - add GET /api/v6/users/@me/channels, for DM fetching
 - add POST /api/v6/users/@me/channels, for DM creation
 - add POST /api/v6/users/<int:user_id>/channels for DM / Group DM
    creation

 - schemas: add CREATE, CREATE_GROUP_DM
 - storage: add last_message_id fetching for channels
 - storage: add support for DMs in get_channel
 - storage: add Storage.get_dm, Storage.get_dms, Storage.get_all_dms
 - schema.sql: add dm_channel_state table
 - schema.sql: add constriants for messages.channel_id and channel_pins
This commit is contained in:
Luna Mendes 2018-10-03 21:43:16 -03:00
parent 88bc4ceca8
commit 61e36f244b
6 changed files with 364 additions and 24 deletions

View File

@ -5,7 +5,7 @@ from logbook import Logger
from ..auth import token_check from ..auth import token_check
from ..snowflake import get_snowflake, snowflake_datetime from ..snowflake import get_snowflake, snowflake_datetime
from ..enums import ChannelType, MessageType from ..enums import ChannelType, MessageType, GUILD_CHANS
from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound from ..errors import Forbidden, BadRequest, ChannelNotFound, MessageNotFound
from ..schemas import validate, MESSAGE_CREATE from ..schemas import validate, MESSAGE_CREATE
@ -18,13 +18,14 @@ bp = Blueprint('channels', __name__)
async def channel_check(user_id, channel_id): async def channel_check(user_id, channel_id):
"""Check if the current user is authorized """Check if the current user is authorized
to read the channel's information.""" to read the channel's information."""
ctype = await app.storage.get_chan_type(channel_id) chan_type = await app.storage.get_chan_type(channel_id)
if ctype is None: if chan_type is None:
raise ChannelNotFound(f'channel type not found') raise ChannelNotFound(f'channel type not found')
if ChannelType(ctype) in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE, ctype = ChannelType(chan_type)
ChannelType.GUILD_CATEGORY):
if ctype in GUILD_CHANS:
guild_id = await app.db.fetchval(""" guild_id = await app.db.fetchval("""
SELECT guild_id SELECT guild_id
FROM guild_channels FROM guild_channels
@ -34,10 +35,25 @@ async def channel_check(user_id, channel_id):
await guild_check(user_id, guild_id) await guild_check(user_id, guild_id)
return guild_id return guild_id
if ctype == ChannelType.DM:
parties = await app.db.fetchval("""
SELECT party1_id, party2_id
FROM dm_channels
WHERE id = $1 AND (party1_id = $2 OR party2_id = $2)
""", channel_id, user_id)
# get the id of the other party
parties.remove(user_id)
return parties[0]
@bp.route('/<int:channel_id>', methods=['GET']) @bp.route('/<int:channel_id>', methods=['GET'])
async def get_channel(channel_id): async def get_channel(channel_id):
"""Get a single channel's information"""
user_id = await token_check() user_id = await token_check()
# channel_check takes care of checking
# DMs and group DMs
await channel_check(user_id, channel_id) await channel_check(user_id, channel_id)
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
@ -47,6 +63,129 @@ async def get_channel(channel_id):
return jsonify(chan) return jsonify(chan)
async def __guild_chan_sql(guild_id, channel_id, field: str) -> str:
"""Update a guild's channel id field to NULL,
if it was set to the given channel id before."""
return await app.db.execute(f"""
UPDATE guilds
SET {field} = NULL
WHERE guilds.id = $1 AND {field} = $2
""", guild_id, channel_id)
async def _update_guild_chan_text(guild_id: int, channel_id: int):
res_embed = await __guild_chan_sql(
guild_id, channel_id, 'embed_channel_id')
res_widget = await __guild_chan_sql(
guild_id, channel_id, 'widget_channel_id')
res_system = await __guild_chan_sql(
guild_id, channel_id, 'system_channel_id')
# if none of them were actually updated,
# ignore and dont dispatch anything
if 'UPDATE 1' not in (res_embed, res_widget, res_system):
return
# at least one of the fields were updated,
# dispatch GUILD_UPDATE
guild = await app.storage.get_guild(guild_id)
await app.dispatcher.dispatch_guild(
guild_id, 'GUILD_UPDATE', guild)
async def _update_guild_chan_voice(guild_id: int, channel_id: int):
res = await __guild_chan_sql(guild_id, channel_id, 'afk_channel_id')
# guild didnt update
if res == 'UPDATE 0':
return
guild = await app.storage.get_guild(guild_id)
await app.dispatcher.dispatch_guild(
guild_id, 'GUILD_UPDATE', guild)
async def _update_guild_chan_cat(guild_id: int, channel_id: int):
# get all channels that were childs of the category
childs = await app.db.fetch("""
SELECT id
FROM guild_channels
WHERE guild_id = $1 AND parent_id = $2
""", guild_id, channel_id)
childs = [c['id'] for c in childs]
# update every child channel to parent_id = NULL
await app.db.execute("""
UPDATE guild_channels
SET parent_id = NULL
WHERE guild_id = $1 AND parent_id = $2
""", guild_id, channel_id)
# tell all people in the guild of the category removal
for child_id in childs:
child = await app.storage.get_channel(child_id)
await app.dispatcher.dispatch_guild(
guild_id, 'CHANNEL_UPDATE', child
)
@bp.route('/<int:channel_id>', methods=['DELETE'])
async def close_channel(channel_id):
user_id = await token_check()
chan_type = await app.storage.get_chan_type(channel_id)
ctype = ChannelType(chan_type)
if ctype in GUILD_CHANS:
guild_id = await channel_check(user_id, channel_id)
chan = await app.storage.get_channel(channel_id)
# the selected function will take care of checking
# the sanity of tables once the channel becomes deleted.
_update_func = {
ChannelType.GUILD_TEXT: _update_guild_chan_text,
ChannelType.GUILD_VOICE: _update_guild_chan_voice,
ChannelType.GUILD_CATEGORY: _update_guild_chan_cat,
}[ctype]
await _update_func(guild_id, channel_id)
# this should take care of deleting all messages as well
# (if any)
await app.db.execute("""
DELETE FROM guild_channels
WHERE id = $1
""", channel_id)
await app.dispatcher.dispatch_guild(
guild_id, 'CHANNEL_DELETE', chan)
return jsonify(chan)
if ctype == ChannelType.DM:
chan = await app.storage.get_channel(channel_id)
# we don't ever actually delete DM channels off the database.
# instead, we close the channel for the user that is making
# the request via removing the link between them and
# the channel on dm_channel_state
await app.db.execute("""
DELETE FROM dm_channel_state (user_id, dm_id)
VALUES ($1, $2)
""", user_id, channel_id)
# nothing happens to the other party of the dm channel
await app.dispacher.dispatch_user(user_id, 'CHANNEL_DELETE', chan)
return jsonify(chan)
if ctype == ChannelType.GROUP_DM:
# TODO: group dm
pass
return '', 404
@bp.route('/<int:channel_id>/messages', methods=['GET']) @bp.route('/<int:channel_id>/messages', methods=['GET'])
async def get_messages(channel_id): async def get_messages(channel_id):
user_id = await token_check() user_id = await token_check()
@ -82,7 +221,6 @@ async def get_single_message(channel_id, message_id):
await channel_check(user_id, channel_id) await channel_check(user_id, channel_id)
# TODO: check READ_MESSAGE_HISTORY permissions # TODO: check READ_MESSAGE_HISTORY permissions
message = await app.storage.get_message(message_id) message = await app.storage.get_message(message_id)
if not message: if not message:
@ -120,6 +258,8 @@ async def create_message(channel_id):
) )
# TODO: dispatch_channel # TODO: dispatch_channel
# we really need dispatch_channel to make dm messages work,
# since they aren't part of any existing guild.
payload = await app.storage.get_message(message_id) payload = await app.storage.get_message(message_id)
await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload) await app.dispatcher.dispatch_guild(guild_id, 'MESSAGE_CREATE', payload)
@ -266,6 +406,7 @@ async def delete_pin(channel_id, message_id):
timestamp = snowflake_datetime(row['message_id']) timestamp = snowflake_datetime(row['message_id'])
# TODO: dispatch_channel
await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', { await app.dispatcher.dispatch_guild(guild_id, 'CHANNEL_PINS_UPDATE', {
'channel_id': str(channel_id), 'channel_id': str(channel_id),
'last_pin_timestamp': timestamp.isoformat() 'last_pin_timestamp': timestamp.isoformat()
@ -279,6 +420,7 @@ async def trigger_typing(channel_id):
user_id = await token_check() user_id = await token_check()
guild_id = await channel_check(user_id, channel_id) guild_id = await channel_check(user_id, channel_id)
# TODO: dispatch_channel
await app.dispatcher.dispatch_guild(guild_id, 'TYPING_START', { await app.dispatcher.dispatch_guild(guild_id, 'TYPING_START', {
'channel_id': str(channel_id), 'channel_id': str(channel_id),
'user_id': str(user_id), 'user_id': str(user_id),

View File

@ -2,8 +2,11 @@ from quart import Blueprint, jsonify, request, current_app as app
from asyncpg import UniqueViolationError from asyncpg import UniqueViolationError
from ..auth import token_check from ..auth import token_check
from ..snowflake import get_snowflake
from ..errors import Forbidden, BadRequest from ..errors import Forbidden, BadRequest
from ..schemas import validate, USER_SETTINGS from ..schemas import validate, USER_SETTINGS, CREATE_DM, CREATE_GROUP_DM
from .guilds import guild_check
bp = Blueprint('user', __name__) bp = Blueprint('user', __name__)
@ -84,15 +87,31 @@ async def get_me_guilds():
@bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE']) @bp.route('/@me/guilds/<int:guild_id>', methods=['DELETE'])
async def leave_guild(guild_id): async def leave_guild(guild_id: int):
user_id = await token_check() user_id = await token_check()
await guild_check(user_id, guild_id)
await app.db.execute(""" await app.db.execute("""
DELETE FROM members DELETE FROM members
WHERE user_id = $1 AND guild_id = $2 WHERE user_id = $1 AND guild_id = $2
""", user_id, guild_id) """, user_id, guild_id)
# TODO: something to dispatch events to the users # first dispatch guild delete to the user,
# then remove from the guild,
# then tell the others that the member was removed
await app.dispatcher.dispatch_user_guild(
user_id, guild_id, 'GUILD_DELETE', {
'id': str(guild_id),
'unavailable': False,
}
)
await app.dispatcher.unsub_guild(guild_id, user_id)
await app.dispatcher.dispatch_guild('GUILD_MEMBER_REMOVE', {
'guild_id': str(guild_id),
'user': await app.storage.get_user(user_id)
})
return '', 204 return '', 204
@ -102,14 +121,75 @@ async def get_connections():
pass pass
# @bp.route('/@me/channels', methods=['GET']) @bp.route('/@me/channels', methods=['GET'])
async def get_dms(): async def get_dms():
pass user_id = await token_check()
dms = await app.storage.get_dms(user_id)
return jsonify(dms)
# @bp.route('/@me/channels', methods=['POST']) async def try_dm_state(user_id, dm_id):
"""Try insertin the user into the dm state
for the given DM."""
try:
await app.db.execute("""
INSERT INTO dm_channel_state (id, dm_id)
VALUES ($1, $2)
""", user_id, dm_id)
except UniqueViolationError:
# if already in state, ignore
pass
async def create_dm(user_id, recipient_id):
dm_id = get_snowflake()
try:
await app.db.execute("""
INSERT INTO dm_channels (id, party1_id, party2_id)
VALUES ($1, $2, $3)
""", dm_id, user_id, recipient_id)
await try_dm_state(user_id, dm_id)
except UniqueViolationError:
# the dm already exists
dm_id = await app.db.fetchval("""
SELECT id
FROM dm_channels
WHERE (party1_id = $1 OR party2_id = $1) AND
(party2_id = $2 OR party2_id = $2)
""", user_id, recipient_id)
dm = await app.storage.get_dm(dm_id, user_id)
return jsonify(dm)
@bp.route('/@me/channels', methods=['POST'])
async def start_dm(): async def start_dm():
pass """Create a DM with a user."""
user_id = await token_check()
j = validate(await request.get_json(), CREATE_DM)
recipient_id = j['recipient_id']
return await create_dm(user_id, recipient_id)
@bp.route('/<int:user_id>/channels', methods=['POST'])
async def create_group_dm(p_user_id: int):
"""Create a DM or a Group DM with user(s)."""
user_id = await token_check()
assert user_id == p_user_id
j = validate(await request.get_json(), CREATE_GROUP_DM)
recipients = j['recipients']
if list(recipients) == 1:
# its a group dm with 1 user... a dm!
return await create_dm(user_id, int(recipients[0]))
# TODO: group dms
return '', 500
@bp.route('/@me/notes/<int:target_id>', methods=['PUT']) @bp.route('/@me/notes/<int:target_id>', methods=['PUT'])

View File

@ -17,6 +17,11 @@ class ChannelType(EasyEnum):
GUILD_CATEGORY = 4 GUILD_CATEGORY = 4
GUILD_CHANS = (ChannelType.GUILD_TEXT,
ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY)
class ActivityType(EasyEnum): class ActivityType(EasyEnum):
PLAYING = 0 PLAYING = 0
STREAMING = 1 STREAMING = 1

View File

@ -279,3 +279,16 @@ RELATIONSHIP = {
'default': RelationshipType.FRIEND.value 'default': RelationshipType.FRIEND.value
} }
} }
CREATE_DM = {
'recipient_id': {
'type': 'snowflake',
'required': True
}
}
CREATE_GROUP_DM = {
'type': 'list',
'required': True,
'schema': {'type': 'snowflake'}
}

View File

@ -36,6 +36,16 @@ async def _set_json(con):
) )
def _filter_recipients(recipients: List[Dict[str, Any]], user_id: int):
"""Filter recipients in a list of recipients, removing
the one that is reundant (ourselves)."""
user_id = str(user_id)
return filter(
lambda recipient: recipient['id'] != user_id,
recipients)
class Storage: class Storage:
"""Class for common SQL statements.""" """Class for common SQL statements."""
def __init__(self, db): def __init__(self, db):
@ -215,12 +225,19 @@ class Storage:
members = await self.get_member_multi(guild_id, mids) members = await self.get_member_multi(guild_id, mids)
return members return members
async def _chan_last_message(self, channel_id: int):
return await self.db.fetch("""
SELECT MAX(id)
FROM messages
WHERE channel_id = $1
""", channel_id)
async def _channels_extra(self, row) -> Dict: async def _channels_extra(self, row) -> Dict:
"""Fill in more information about a channel.""" """Fill in more information about a channel."""
channel_type = row['type'] channel_type = row['type']
# TODO: dm and group dm?
chan_type = ChannelType(channel_type) chan_type = ChannelType(channel_type)
if chan_type == ChannelType.GUILD_TEXT: if chan_type == ChannelType.GUILD_TEXT:
topic = await self.db.fetchval(""" topic = await self.db.fetchval("""
SELECT topic FROM guild_text_channels SELECT topic FROM guild_text_channels
@ -229,6 +246,8 @@ class Storage:
return {**row, **{ return {**row, **{
'topic': topic, 'topic': topic,
'last_message_id': str(
await self._chan_last_message(row['id']))
}} }}
elif chan_type == ChannelType.GUILD_VOICE: elif chan_type == ChannelType.GUILD_VOICE:
vrow = await self.db.fetchval(""" vrow = await self.db.fetchval("""
@ -240,7 +259,8 @@ class Storage:
log.warning('unknown channel type: {}', chan_type) log.warning('unknown channel type: {}', chan_type)
async def get_chan_type(self, channel_id) -> int: async def get_chan_type(self, channel_id: int) -> int:
"""Get the channel type integer, given channel ID."""
return await self.db.fetchval(""" return await self.db.fetchval("""
SELECT channel_type SELECT channel_type
FROM channels FROM channels
@ -275,13 +295,14 @@ class Storage:
return list(map(_overwrite_convert, overwrite_rows)) return list(map(_overwrite_convert, overwrite_rows))
async def get_channel(self, channel_id) -> Dict[str, Any]: async def get_channel(self, channel_id: int) -> Dict[str, Any]:
"""Fetch a single channel's information.""" """Fetch a single channel's information."""
chan_type = await self.get_chan_type(channel_id) chan_type = await self.get_chan_type(channel_id)
ctype = ChannelType(chan_type)
if ChannelType(chan_type) in (ChannelType.GUILD_TEXT, if ctype in (ChannelType.GUILD_TEXT,
ChannelType.GUILD_VOICE, ChannelType.GUILD_VOICE,
ChannelType.GUILD_CATEGORY): ChannelType.GUILD_CATEGORY):
base = await self.db.fetchrow(""" base = await self.db.fetchrow("""
SELECT id, guild_id::text, parent_id, name, position, nsfw SELECT id, guild_id::text, parent_id, name, position, nsfw
FROM guild_channels FROM guild_channels
@ -297,10 +318,35 @@ class Storage:
res['id'] = str(res['id']) res['id'] = str(res['id'])
return res return res
else: elif ctype == ChannelType.DM:
# TODO: dms and group dms dm_row = await self.db.fetchrow("""
SELECT party1_id, party2_id
FROM dm_channels
WHERE id = $1
""", channel_id)
drow = dict(dm_row)
drow['type'] = chan_type
drow['last_message_id'] = str(
await self._chan_last_message(channel_id))
# dms have just two recipients.
drow['recipients'] = [
await self.get_user(drow['party1_id']),
await self.get_user(drow['party2_id'])
]
drow.pop('party1_id')
drow.pop('party2_id')
drow['id'] = str(drow['id'])
return drow
elif ctype == ChannelType.GROUP_DM:
pass pass
return None
async def get_channel_data(self, guild_id) -> List[Dict]: async def get_channel_data(self, guild_id) -> List[Dict]:
"""Get channel information on a guild""" """Get channel information on a guild"""
channel_basics = await self.db.fetch(""" channel_basics = await self.db.fetch("""
@ -574,6 +620,7 @@ class Storage:
return dinv return dinv
async def get_user_settings(self, user_id: int) -> Dict[str, Any]: async def get_user_settings(self, user_id: int) -> Dict[str, Any]:
"""Get current user settings."""
row = await self._fetchrow_with_json(""" row = await self._fetchrow_with_json("""
SELECT * SELECT *
FROM user_settings FROM user_settings
@ -688,3 +735,49 @@ class Storage:
res.append(drow) res.append(drow)
return res return res
async def get_dm(self, dm_id: int, user_id: int = None):
dm_chan = await self.get_channel(dm_id)
if user_id:
dm_chan['recipients'] = _filter_recipients(
dm_chan['recipients'], user_id
)
return dm_chan
async def get_dms(self, user_id: int) -> List[Dict[str, Any]]:
"""Get all DM channels for a user, including group DMs.
This will only fetch channels the user has in their state,
which is different than the whole list of DM channels.
"""
dm_ids = await self.db.fetch("""
SELECT id
FROM dm_channel_state
WHERE user_id = $1
""", user_id)
res = []
for dm_id in dm_ids:
dm_chan = await self.get_dm(dm_id, user_id)
res.append(dm_chan)
return res
async def get_all_dms(self, user_id: int) -> List[Dict[str, Any]]:
"""Get all DMs for a user, regardless of the DM state."""
dm_ids = await self.db.fetch("""
SELECT id
FROM dm_channels
WHERE party1_id = $1 OR party2_id = $2
""", user_id)
res = []
for dm_id in dm_ids:
dm_chan = await self.get_dm(dm_id, user_id)
res.append(dm_chan)
return res

View File

@ -261,6 +261,13 @@ CREATE TABLE IF NOT EXISTS dm_channels (
); );
CREATE TABLE IF NOT EXISTS dm_channel_state (
user_id bigint REFERENCES users (id) ON DELETE CASCADE,
dm_id bigint REFERENCES dm_channels (id) ON DELETE CASCADE,
PRIMARY KEY (user_id, dm_id)
);
CREATE TABLE IF NOT EXISTS group_dm_channels ( CREATE TABLE IF NOT EXISTS group_dm_channels (
id bigint REFERENCES channels (id) ON DELETE CASCADE, id bigint REFERENCES channels (id) ON DELETE CASCADE,
owner_id bigint REFERENCES users (id), owner_id bigint REFERENCES users (id),
@ -440,7 +447,7 @@ CREATE TABLE IF NOT EXISTS embeds (
CREATE TABLE IF NOT EXISTS messages ( CREATE TABLE IF NOT EXISTS messages (
id bigint PRIMARY KEY, id bigint PRIMARY KEY,
channel_id bigint REFERENCES channels (id), channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
-- those are mutually exclusive, only one of them -- those are mutually exclusive, only one of them
-- can NOT be NULL at a time. -- can NOT be NULL at a time.
@ -486,7 +493,7 @@ CREATE TABLE IF NOT EXISTS message_reactions (
); );
CREATE TABLE IF NOT EXISTS channel_pins ( CREATE TABLE IF NOT EXISTS channel_pins (
channel_id bigint REFERENCES channels (id), channel_id bigint REFERENCES channels (id) ON DELETE CASCADE,
message_id bigint REFERENCES messages (id), message_id bigint REFERENCES messages (id) ON DELETE CASCADE,
PRIMARY KEY (channel_id, message_id) PRIMARY KEY (channel_id, message_id)
); );