Merge branch 'group-dms' into 'master'

Group DMs

See merge request litecord/litecord!18
This commit is contained in:
Luna 2019-02-17 03:24:08 +00:00
commit da7ef70458
16 changed files with 532 additions and 42 deletions

View File

@ -32,7 +32,9 @@ from .icons import bp as icons
from .nodeinfo import bp as nodeinfo
from .static import bp as static
from .attachments import bp as attachments
from .dm_channels import bp as dm_channels
__all__ = ['gateway', 'auth', 'users', 'guilds', 'channels',
'webhooks', 'science', 'voice', 'invites', 'relationships',
'dms', 'icons', 'nodeinfo', 'static', 'attachments']
'dms', 'icons', 'nodeinfo', 'static', 'attachments',
'dm_channels']

View File

@ -23,13 +23,14 @@ from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
from litecord.auth import token_check
from litecord.enums import ChannelType, GUILD_CHANS
from litecord.enums import ChannelType, GUILD_CHANS, MessageType
from litecord.errors import ChannelNotFound
from litecord.schemas import (
validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL
validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE
)
from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.system_messages import send_sys_message
log = Logger(__name__)
bp = Blueprint('channels', __name__)
@ -405,31 +406,73 @@ async def _update_voice_channel(channel_id: int, j: dict):
await _common_guild_chan(channel_id, j)
async def _update_group_dm(channel_id: int, j: dict, author_id: int):
if 'name' in j:
await app.db.execute("""
UPDATE group_dm_channels
SET name = $1
WHERE id = $2
""", j['name'], channel_id)
await send_sys_message(
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
)
if 'icon' in j:
new_icon = await app.icons.update(
'channel-icons', channel_id, j['icon'], always_icon=True
)
await app.db.execute("""
UPDATE group_dm_channels
SET icon = $1
WHERE id = $2
""", new_icon.icon_hash, channel_id)
await send_sys_message(
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
)
@bp.route('/<int:channel_id>', methods=['PUT', 'PATCH'])
async def update_channel(channel_id):
"""Update a channel's information"""
user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id)
if ctype not in GUILD_CHANS:
raise ChannelNotFound('Can not edit non-guild channels.')
if ctype not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
ChannelType.GROUP_DM):
raise ChannelNotFound('unable to edit unsupported chan type')
await channel_perm_check(user_id, channel_id, 'manage_channels')
j = validate(await request.get_json(), CHAN_UPDATE)
is_guild = ctype in GUILD_CHANS
# TODO: categories?
if is_guild:
await channel_perm_check(user_id, channel_id, 'manage_channels')
j = validate(await request.get_json(),
CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
# TODO: categories
update_handler = {
ChannelType.GUILD_TEXT: _update_text_channel,
ChannelType.GUILD_VOICE: _update_voice_channel,
ChannelType.GROUP_DM: _update_group_dm,
}[ctype]
await _update_channel_common(channel_id, guild_id, j)
await update_handler(channel_id, j)
if is_guild:
await _update_channel_common(channel_id, guild_id, j)
await update_handler(channel_id, j, user_id)
chan = await app.storage.get_channel(channel_id)
if is_guild:
await app.dispatcher.dispatch(
'guild', guild_id, 'CHANNEL_UPDATE', chan)
else:
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_UPDATE', chan)
await app.dispatcher.dispatch('guild', guild_id, 'CHANNEL_UPDATE', chan)
return jsonify(chan)

View File

@ -17,6 +17,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import Union, List
from quart import current_app as app
from litecord.enums import ChannelType, GUILD_CHANS
@ -53,16 +55,23 @@ async def guild_owner_check(user_id: int, guild_id: int):
raise Forbidden('You are not the owner of the guild')
async def channel_check(user_id, channel_id):
async def channel_check(user_id, channel_id, *,
only: Union[ChannelType, List[ChannelType]] = None):
"""Check if the current user is authorized
to read the channel's information."""
chan_type = await app.storage.get_chan_type(channel_id)
if chan_type is None:
raise ChannelNotFound(f'channel type not found')
raise ChannelNotFound('channel type not found')
ctype = ChannelType(chan_type)
if only and not isinstance(only, list):
only = [only]
if only and ctype not in only:
raise ChannelNotFound('invalid channel type')
if ctype in GUILD_CHANS:
guild_id = await app.db.fetchval("""
SELECT guild_id
@ -77,6 +86,15 @@ async def channel_check(user_id, channel_id):
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
return ctype, peer_id
if ctype == ChannelType.GROUP_DM:
owner_id = await app.db.fetchval("""
SELECT owner_id
FROM group_dm_channels
WHERE id = $1
""", channel_id)
return ctype, owner_id
async def guild_perm_check(user_id, guild_id, permission: str):
"""Check guild permissions for a user."""

View File

@ -0,0 +1,182 @@
"""
Litecord
Copyright (C) 2018-2019 Luna Mendes
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from quart import Blueprint, current_app as app, jsonify
from logbook import Logger
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check
from litecord.enums import ChannelType, MessageType
from litecord.errors import BadRequest, Forbidden
from litecord.snowflake import get_snowflake
from litecord.system_messages import send_sys_message
from litecord.pubsub.channel import gdm_recipient_view
log = Logger(__name__)
bp = Blueprint('dm_channels', __name__)
async def _raw_gdm_add(channel_id, user_id):
await app.db.execute("""
INSERT INTO group_dm_members (id, member_id)
VALUES ($1, $2)
""", channel_id, user_id)
async def _raw_gdm_remove(channel_id, user_id):
await app.db.execute("""
DELETE FROM group_dm_members
WHERE id = $1 AND member_id = $2
""", channel_id, user_id)
async def _gdm_create(user_id, peer_id) -> int:
"""Create a group dm, given two users.
Returns the new GDM id.
"""
channel_id = get_snowflake()
await app.db.execute("""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""", channel_id, ChannelType.GROUP_DM.value)
await app.db.execute("""
INSERT INTO group_dm_channels (id, owner_id, name, icon)
VALUES ($1, $2, NULL, NULL)
""", channel_id, user_id)
await _raw_gdm_add(channel_id, user_id)
await _raw_gdm_add(channel_id, peer_id)
await app.dispatcher.sub('channel', channel_id, user_id)
await app.dispatcher.sub('channel', channel_id, peer_id)
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch('channel', channel_id, 'CHANNEL_CREATE', chan)
return channel_id
async def _gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
"""Add a recipient to a Group DM.
Dispatches:
- A system message with the join (depending of user_id)
- A CHANNEL_CREATE to the peer.
- A CHANNEL_UPDATE to all remaining recipients.
"""
await _raw_gdm_add(channel_id, peer_id)
chan = await app.storage.get_channel(channel_id)
# the reasoning behind gdm_recipient_view is in its docstring.
await app.dispatcher.dispatch(
'user', peer_id, 'CHANNEL_CREATE', gdm_recipient_view(chan, peer_id))
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_UPDATE', chan)
await app.dispatcher.sub('channel', peer_id)
if user_id:
await send_sys_message(
app, channel_id, MessageType.RECIPIENT_ADD,
user_id, peer_id
)
async def _gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
"""Remove a member from a GDM.
Dispatches:
- A system message with the leave or forced removal (depending if user_id)
exists or not.
- A CHANNEL_DELETE to the peer.
- A CHANNEL_UPDATE to all remaining recipients.
"""
await _raw_gdm_remove(channel_id, peer_id)
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch(
'user', peer_id, 'CHANNEL_DELETE', gdm_recipient_view(chan, user_id))
await app.dispatcher.unsub('channel', peer_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_RECIPIENT_REMOVE', {
'channel_id': str(channel_id),
'user': await app.storage.get_user(peer_id)
}
)
if user_id:
await send_sys_message(
app, channel_id, MessageType.RECIPIENT_REMOVE,
user_id, peer_id
)
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['PUT'])
async def add_to_group_dm(dm_chan, peer_id):
"""Adds a member to a group dm OR creates a group dm."""
user_id = await token_check()
# other_id is the owner of the group dm (gdm) if the
# given channel is a gdm
# other_id is the peer of the dm if the given channel is a dm
ctype, other_id = await channel_check(
user_id, dm_chan,
only=[ChannelType.DM, ChannelType.GROUP_DM]
)
# check relationship with the given user id
# and the user id making the request
friends = await app.user_storage.are_friends_with(user_id, peer_id)
if not friends:
raise BadRequest('Cant insert peer into dm')
if ctype == ChannelType.DM:
dm_chan = await _gdm_create(
user_id, other_id
)
await _gdm_add_recipient(dm_chan, peer_id, user_id=user_id)
return jsonify(
await app.storage.get_channel(dm_chan)
)
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['DELETE'])
async def remove_from_group_dm(dm_chan, peer_id):
"""Remove users from group dm."""
user_id = await token_check()
_ctype, owner_id = await channel_check(
user_id, dm_chan, only=ChannelType.GROUP_DM
)
if owner_id != user_id:
raise Forbidden('You are now the owner of the group DM')
await _gdm_remove_recipient(dm_chan, peer_id)
return '', 204

View File

@ -76,3 +76,9 @@ async def _get_user_avatar(user_id, avatar_file):
# @bp.route('/app-icons/<int:application_id>/<icon_hash>.<ext>')
async def get_app_icon(application_id, icon_hash, ext):
pass
@bp.route('/channel-icons/<int:channel_id>/<icon_file>', methods=['GET'])
async def _get_gdm_icon(guild_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file)
return await send_icon('channel-icons', guild_id, icon_hash, ext=ext)

View File

@ -384,9 +384,10 @@ async def get_profile(peer_id: int):
return '', 404
mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id)
friends = await app.user_storage.are_friends_with(user_id, peer_id)
# don't return a proper card if no guilds are being shared.
if not mutuals:
if not mutuals and not friends:
return '', 404
# actual premium status is determined by that

View File

@ -103,6 +103,15 @@ class EventDispatcher:
for key in keys:
await self.subscribe(backend_str, key, identifier)
async def mass_sub(self, identifier: Any,
backends: List[tuple]):
"""Mass subscribe to many backends at once."""
for backend_str, keys in backends:
log.debug('subscribing {} to {} keys in backend {}',
identifier, len(keys), backend_str)
await self.sub_many(backend_str, identifier, keys)
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
"""Dispatch an event to the backend.

View File

@ -77,14 +77,14 @@ class LitecordError(Exception):
@property
def message(self) -> str:
"""Get an error's message string."""
err_code = getattr(self, 'error_code', None)
if err_code is not None:
return ERR_MSG_MAP.get(err_code) or self.args[0]
try:
return self.args[0]
except IndexError:
err_code = getattr(self, 'error_code', None)
if err_code is not None:
return ERR_MSG_MAP.get(err_code) or self.args[0]
return repr(self)
@property

View File

@ -372,11 +372,16 @@ class GatewayWebsocket:
# user, fetch info
uready = await self._user_ready()
private_channels = (
await self.user_storage.get_dms(user_id) +
await self.user_storage.get_gdms(user_id)
)
await self.dispatch('READY', {**{
'v': 6,
'user': user,
'private_channels': await self.user_storage.get_dms(user_id),
'private_channels': private_channels,
'guilds': guilds,
'session_id': self.state.session_id,
@ -437,17 +442,24 @@ class GatewayWebsocket:
by GuildDispatcher.sub
"""
user_id = self.state.user_id
guild_ids = await self._guild_ids()
log.info('subscribing to {} guilds', len(guild_ids))
await self.ext.dispatcher.sub_many('guild', user_id, guild_ids)
# subscribe the user to all dms they have OPENED.
dms = await self.user_storage.get_dms(user_id)
dm_ids = [int(dm['id']) for dm in dms]
# fetch all group dms the user is a member of.
gdm_ids = await self.user_storage.get_gdms_internal(user_id)
log.info('subscribing to {} guilds', len(guild_ids))
log.info('subscribing to {} dms', len(dm_ids))
await self.ext.dispatcher.sub_many('channel', user_id, dm_ids)
log.info('subscribing to {} group dms', len(gdm_ids))
await self.ext.dispatcher.mass_sub(user_id, [
('guild', guild_ids),
('channel', dm_ids),
('channel', gdm_ids)
])
if not self.state.bot:
# subscribe to all friends

View File

@ -161,21 +161,18 @@ def parse_data_uri(string) -> tuple:
def _gen_update_sql(scope: str) -> str:
field = {
'user': 'avatar',
'guild': 'icon'
'guild': 'icon',
'channel-icons': 'icon',
}[scope]
table = {
'user': 'users',
'guild': 'guilds'
}[scope]
col = {
'user': 'id',
'guild': 'id'
'guild': 'guilds',
'channel-icons': 'group_dm_channels'
}[scope]
return f"""
SELECT {field} FROM {table} WHERE {col} = $1
SELECT {field} FROM {table} WHERE id = $1
"""
@ -277,6 +274,9 @@ class IconManager:
async def generic_get(self, scope, key, icon_hash, **kwargs) -> Icon:
"""Get any icon."""
if icon_hash is None:
return None
log.debug('GET {} {} {}', scope, key, icon_hash)
key = str(key)
@ -409,6 +409,12 @@ class IconManager:
WHERE icon = $1
""", icon.icon_hash)
await self.storage.db.execute("""
UPDATE group_dm_channels
SET icon = NULL
WHERE icon = $1
""", icon.icon_hash)
await self.storage.db.execute("""
DELETE FROM icons
WHERE hash = $1

View File

@ -22,10 +22,32 @@ from typing import Any
from logbook import Logger
from .dispatcher import DispatcherWithState
from litecord.enums import ChannelType
from litecord.utils import index_by_func
log = Logger(__name__)
def gdm_recipient_view(orig: dict, user_id: int) -> dict:
"""Create a copy of the original channel object that doesn't
show the user we are dispatching it to.
this only applies to group dms and discords' api design that says
a group dms' recipients must not show the original user.
"""
# make a copy or the original channel object
data = dict(orig)
idx = index_by_func(
lambda user: user['id'] == str(user_id),
data['recipients']
)
data['recipients'].pop(idx)
return data
class ChannelDispatcher(DispatcherWithState):
"""Main channel Pub/Sub logic."""
KEY_TYPE = int
@ -62,7 +84,19 @@ class ChannelDispatcher(DispatcherWithState):
await self.unsub(channel_id, user_id)
continue
cur_sess = await self._dispatch_states(states, event, data)
cur_sess = 0
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
and data.get('type') == ChannelType.GROUP_DM.value:
# we edit the channel payload so it doesn't show
# the user as a recipient
new_data = gdm_recipient_view(data, user_id)
cur_sess = await self._dispatch_states(
states, event, new_data)
else:
cur_sess = await self._dispatch_states(
states, event, data)
sessions.extend(cur_sess)
dispatched += len(cur_sess)

View File

@ -581,6 +581,14 @@ CREATE_GROUP_DM = {
},
}
GROUP_DM_UPDATE = {
'name': {
'type': 'guild_name',
'required': False
},
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
}
SPECIFIC_FRIEND = {
'username': {'type': 'username'},
'discriminator': {'type': 'discriminator'}

View File

@ -353,7 +353,38 @@ class Storage:
return list(map(_overwrite_convert, overwrite_rows))
async def get_channel(self, channel_id: int) -> Dict[str, Any]:
async def gdm_recipient_ids(self, channel_id: int) -> List[int]:
"""Get the list of user IDs that are recipients of the
given Group DM."""
user_ids = await self.db.fetch("""
SELECT member_id
FROM group_dm_members
JOIN users
ON member_id = users.id
WHERE group_dm_members.id = $1
ORDER BY username DESC
""", channel_id)
return [r['member_id'] for r in user_ids]
async def _gdm_recipients(self, channel_id: int,
reference_id: int = None) -> List[int]:
"""Get the list of users that are recipients of the
given Group DM."""
recipients = await self.gdm_recipient_ids(channel_id)
res = []
for user_id in recipients:
if user_id == reference_id:
continue
res.append(
await self.get_user(user_id)
)
return res
async def get_channel(self, channel_id: int, **kwargs) -> Dict[str, Any]:
"""Fetch a single channel's information."""
chan_type = await self.get_chan_type(channel_id)
ctype = ChannelType(chan_type)
@ -387,7 +418,8 @@ class Storage:
drow['type'] = chan_type
drow['last_message_id'] = await self.chan_last_message_str(
channel_id)
channel_id
)
# dms have just two recipients.
drow['recipients'] = [
@ -401,8 +433,22 @@ class Storage:
drow['id'] = str(drow['id'])
return drow
elif ctype == ChannelType.GROUP_DM:
# TODO: group dms
pass
gdm_row = await self.db.fetchrow("""
SELECT id::text, owner_id::text, name, icon
FROM group_dm_channels
WHERE id = $1
""", channel_id)
drow = dict(gdm_row)
drow['type'] = chan_type
drow['recipients'] = await self._gdm_recipients(
channel_id, kwargs.get('user_id')
)
drow['last_message_id'] = await self.chan_last_message_str(
channel_id
)
return drow
return None

View File

@ -17,11 +17,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from logbook import Logger
from litecord.snowflake import get_snowflake
from litecord.enums import MessageType
log = Logger(__name__)
async def _handle_pin_msg(app, channel_id, pinned_id, author_id):
async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
"""Handle a message pin."""
new_id = get_snowflake()
@ -41,12 +44,108 @@ async def _handle_pin_msg(app, channel_id, pinned_id, author_id):
return new_id
# TODO: decrease repetition between add and remove handlers
async def _handle_recp_add(app, channel_id, author_id, peer_id):
new_id = get_snowflake()
await app.db.execute(
"""
INSERT INTO messages
(id, channel_id, author_id, webhook_id,
content, message_type)
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
f'<@{peer_id}>',
MessageType.RECIPIENT_ADD.value
)
return new_id
async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
new_id = get_snowflake()
await app.db.execute(
"""
INSERT INTO messages
(id, channel_id, author_id, webhook_id,
content, message_type)
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
f'<@{peer_id}>',
MessageType.RECIPIENT_REMOVE.value
)
return new_id
async def _handle_gdm_name_edit(app, channel_id, author_id):
new_id = get_snowflake()
gdm_name = await app.db.fetchval("""
SELECT name FROM group_dm_channels
WHERE id = $1
""", channel_id)
if not gdm_name:
log.warning('no gdm name found for sys message')
return
await app.db.execute(
"""
INSERT INTO messages
(id, channel_id, author_id, webhook_id,
content, message_type)
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
gdm_name,
MessageType.CHANNEL_NAME_CHANGE.value
)
return new_id
async def _handle_gdm_icon_edit(app, channel_id, author_id):
new_id = get_snowflake()
await app.db.execute(
"""
INSERT INTO messages
(id, channel_id, author_id, webhook_id,
content, message_type)
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
'',
MessageType.CHANNEL_ICON_CHANGE.value
)
return new_id
async def send_sys_message(app, channel_id: int, m_type: MessageType,
*args, **kwargs) -> int:
"""Send a system message."""
handler = {
MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg,
}[m_type]
try:
handler = {
MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg,
# gdm specific
MessageType.RECIPIENT_ADD: _handle_recp_add,
MessageType.RECIPIENT_REMOVE: _handle_recp_rmv,
MessageType.CHANNEL_NAME_CHANGE: _handle_gdm_name_edit,
MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit
}[m_type]
except KeyError:
raise ValueError('Invalid system message type')
message_id = await handler(app, channel_id, *args, **kwargs)

View File

@ -338,3 +338,26 @@ class UserStorage:
)
)
""", user_id, peer_id)
async def get_gdms_internal(self, user_id) -> List[int]:
"""Return a list of Group DM IDs the user is a member of."""
rows = await self.db.fetch("""
SELECT id
FROM group_dm_members
WHERE member_id = $1
""", user_id)
return [r['id'] for r in rows]
async def get_gdms(self, user_id) -> List[Dict[str, Any]]:
"""Get list of group DMs a user is in."""
gdm_ids = await self.get_gdms_internal(user_id)
res = []
for gdm_id in gdm_ids:
res.append(
await self.storage.get_channel(gdm_id, user_id=user_id)
)
return res

3
run.py
View File

@ -35,7 +35,7 @@ import config
from litecord.blueprints import (
gateway, auth, users, guilds, channels, webhooks, science,
voice, invites, relationships, dms, icons, nodeinfo, static,
attachments
attachments, dm_channels
)
# those blueprints are separated from the "main" ones
@ -128,6 +128,7 @@ def set_blueprints(app_):
voice: '/voice',
invites: None,
dms: '/users',
dm_channels: '/channels',
fake_store: None,