Merge branch 'group-dms' into 'master'

Group DMs

See merge request litecord/litecord!18
This commit is contained in:
Luna 2019-02-17 03:24:08 +00:00
commit da7ef70458
16 changed files with 532 additions and 42 deletions

View File

@ -32,7 +32,9 @@ from .icons import bp as icons
from .nodeinfo import bp as nodeinfo from .nodeinfo import bp as nodeinfo
from .static import bp as static from .static import bp as static
from .attachments import bp as attachments from .attachments import bp as attachments
from .dm_channels import bp as dm_channels
__all__ = ['gateway', 'auth', 'users', 'guilds', 'channels', __all__ = ['gateway', 'auth', 'users', 'guilds', 'channels',
'webhooks', 'science', 'voice', 'invites', 'relationships', 'webhooks', 'science', 'voice', 'invites', 'relationships',
'dms', 'icons', 'nodeinfo', 'static', 'attachments'] 'dms', 'icons', 'nodeinfo', 'static', 'attachments',
'dm_channels']

View File

@ -23,13 +23,14 @@ from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger from logbook import Logger
from litecord.auth import token_check from litecord.auth import token_check
from litecord.enums import ChannelType, GUILD_CHANS from litecord.enums import ChannelType, GUILD_CHANS, MessageType
from litecord.errors import ChannelNotFound from litecord.errors import ChannelNotFound
from litecord.schemas import ( from litecord.schemas import (
validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE
) )
from litecord.blueprints.checks import channel_check, channel_perm_check from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.system_messages import send_sys_message
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('channels', __name__) bp = Blueprint('channels', __name__)
@ -405,31 +406,73 @@ async def _update_voice_channel(channel_id: int, j: dict):
await _common_guild_chan(channel_id, j) await _common_guild_chan(channel_id, j)
async def _update_group_dm(channel_id: int, j: dict, author_id: int):
if 'name' in j:
await app.db.execute("""
UPDATE group_dm_channels
SET name = $1
WHERE id = $2
""", j['name'], channel_id)
await send_sys_message(
app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id
)
if 'icon' in j:
new_icon = await app.icons.update(
'channel-icons', channel_id, j['icon'], always_icon=True
)
await app.db.execute("""
UPDATE group_dm_channels
SET icon = $1
WHERE id = $2
""", new_icon.icon_hash, channel_id)
await send_sys_message(
app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id
)
@bp.route('/<int:channel_id>', methods=['PUT', 'PATCH']) @bp.route('/<int:channel_id>', methods=['PUT', 'PATCH'])
async def update_channel(channel_id): async def update_channel(channel_id):
"""Update a channel's information""" """Update a channel's information"""
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
if ctype not in GUILD_CHANS: if ctype not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE,
raise ChannelNotFound('Can not edit non-guild channels.') ChannelType.GROUP_DM):
raise ChannelNotFound('unable to edit unsupported chan type')
await channel_perm_check(user_id, channel_id, 'manage_channels') is_guild = ctype in GUILD_CHANS
j = validate(await request.get_json(), CHAN_UPDATE)
# TODO: categories? if is_guild:
await channel_perm_check(user_id, channel_id, 'manage_channels')
j = validate(await request.get_json(),
CHAN_UPDATE if is_guild else GROUP_DM_UPDATE)
# TODO: categories
update_handler = { update_handler = {
ChannelType.GUILD_TEXT: _update_text_channel, ChannelType.GUILD_TEXT: _update_text_channel,
ChannelType.GUILD_VOICE: _update_voice_channel, ChannelType.GUILD_VOICE: _update_voice_channel,
ChannelType.GROUP_DM: _update_group_dm,
}[ctype] }[ctype]
await _update_channel_common(channel_id, guild_id, j) if is_guild:
await update_handler(channel_id, j) await _update_channel_common(channel_id, guild_id, j)
await update_handler(channel_id, j, user_id)
chan = await app.storage.get_channel(channel_id) chan = await app.storage.get_channel(channel_id)
if is_guild:
await app.dispatcher.dispatch(
'guild', guild_id, 'CHANNEL_UPDATE', chan)
else:
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_UPDATE', chan)
await app.dispatcher.dispatch('guild', guild_id, 'CHANNEL_UPDATE', chan)
return jsonify(chan) return jsonify(chan)

View File

@ -17,6 +17,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import Union, List
from quart import current_app as app from quart import current_app as app
from litecord.enums import ChannelType, GUILD_CHANS from litecord.enums import ChannelType, GUILD_CHANS
@ -53,16 +55,23 @@ async def guild_owner_check(user_id: int, guild_id: int):
raise Forbidden('You are not the owner of the guild') raise Forbidden('You are not the owner of the guild')
async def channel_check(user_id, channel_id): async def channel_check(user_id, channel_id, *,
only: Union[ChannelType, List[ChannelType]] = None):
"""Check if the current user is authorized """Check if the current user is authorized
to read the channel's information.""" to read the channel's information."""
chan_type = await app.storage.get_chan_type(channel_id) chan_type = await app.storage.get_chan_type(channel_id)
if chan_type is None: if chan_type is None:
raise ChannelNotFound(f'channel type not found') raise ChannelNotFound('channel type not found')
ctype = ChannelType(chan_type) ctype = ChannelType(chan_type)
if only and not isinstance(only, list):
only = [only]
if only and ctype not in only:
raise ChannelNotFound('invalid channel type')
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:
guild_id = await app.db.fetchval(""" guild_id = await app.db.fetchval("""
SELECT guild_id SELECT guild_id
@ -77,6 +86,15 @@ async def channel_check(user_id, channel_id):
peer_id = await app.storage.get_dm_peer(channel_id, user_id) peer_id = await app.storage.get_dm_peer(channel_id, user_id)
return ctype, peer_id return ctype, peer_id
if ctype == ChannelType.GROUP_DM:
owner_id = await app.db.fetchval("""
SELECT owner_id
FROM group_dm_channels
WHERE id = $1
""", channel_id)
return ctype, owner_id
async def guild_perm_check(user_id, guild_id, permission: str): async def guild_perm_check(user_id, guild_id, permission: str):
"""Check guild permissions for a user.""" """Check guild permissions for a user."""

View File

@ -0,0 +1,182 @@
"""
Litecord
Copyright (C) 2018-2019 Luna Mendes
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from quart import Blueprint, current_app as app, jsonify
from logbook import Logger
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check
from litecord.enums import ChannelType, MessageType
from litecord.errors import BadRequest, Forbidden
from litecord.snowflake import get_snowflake
from litecord.system_messages import send_sys_message
from litecord.pubsub.channel import gdm_recipient_view
log = Logger(__name__)
bp = Blueprint('dm_channels', __name__)
async def _raw_gdm_add(channel_id, user_id):
await app.db.execute("""
INSERT INTO group_dm_members (id, member_id)
VALUES ($1, $2)
""", channel_id, user_id)
async def _raw_gdm_remove(channel_id, user_id):
await app.db.execute("""
DELETE FROM group_dm_members
WHERE id = $1 AND member_id = $2
""", channel_id, user_id)
async def _gdm_create(user_id, peer_id) -> int:
"""Create a group dm, given two users.
Returns the new GDM id.
"""
channel_id = get_snowflake()
await app.db.execute("""
INSERT INTO channels (id, channel_type)
VALUES ($1, $2)
""", channel_id, ChannelType.GROUP_DM.value)
await app.db.execute("""
INSERT INTO group_dm_channels (id, owner_id, name, icon)
VALUES ($1, $2, NULL, NULL)
""", channel_id, user_id)
await _raw_gdm_add(channel_id, user_id)
await _raw_gdm_add(channel_id, peer_id)
await app.dispatcher.sub('channel', channel_id, user_id)
await app.dispatcher.sub('channel', channel_id, peer_id)
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch('channel', channel_id, 'CHANNEL_CREATE', chan)
return channel_id
async def _gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None):
"""Add a recipient to a Group DM.
Dispatches:
- A system message with the join (depending of user_id)
- A CHANNEL_CREATE to the peer.
- A CHANNEL_UPDATE to all remaining recipients.
"""
await _raw_gdm_add(channel_id, peer_id)
chan = await app.storage.get_channel(channel_id)
# the reasoning behind gdm_recipient_view is in its docstring.
await app.dispatcher.dispatch(
'user', peer_id, 'CHANNEL_CREATE', gdm_recipient_view(chan, peer_id))
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_UPDATE', chan)
await app.dispatcher.sub('channel', peer_id)
if user_id:
await send_sys_message(
app, channel_id, MessageType.RECIPIENT_ADD,
user_id, peer_id
)
async def _gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None):
"""Remove a member from a GDM.
Dispatches:
- A system message with the leave or forced removal (depending if user_id)
exists or not.
- A CHANNEL_DELETE to the peer.
- A CHANNEL_UPDATE to all remaining recipients.
"""
await _raw_gdm_remove(channel_id, peer_id)
chan = await app.storage.get_channel(channel_id)
await app.dispatcher.dispatch(
'user', peer_id, 'CHANNEL_DELETE', gdm_recipient_view(chan, user_id))
await app.dispatcher.unsub('channel', peer_id)
await app.dispatcher.dispatch(
'channel', channel_id, 'CHANNEL_RECIPIENT_REMOVE', {
'channel_id': str(channel_id),
'user': await app.storage.get_user(peer_id)
}
)
if user_id:
await send_sys_message(
app, channel_id, MessageType.RECIPIENT_REMOVE,
user_id, peer_id
)
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['PUT'])
async def add_to_group_dm(dm_chan, peer_id):
"""Adds a member to a group dm OR creates a group dm."""
user_id = await token_check()
# other_id is the owner of the group dm (gdm) if the
# given channel is a gdm
# other_id is the peer of the dm if the given channel is a dm
ctype, other_id = await channel_check(
user_id, dm_chan,
only=[ChannelType.DM, ChannelType.GROUP_DM]
)
# check relationship with the given user id
# and the user id making the request
friends = await app.user_storage.are_friends_with(user_id, peer_id)
if not friends:
raise BadRequest('Cant insert peer into dm')
if ctype == ChannelType.DM:
dm_chan = await _gdm_create(
user_id, other_id
)
await _gdm_add_recipient(dm_chan, peer_id, user_id=user_id)
return jsonify(
await app.storage.get_channel(dm_chan)
)
@bp.route('/<int:dm_chan>/recipients/<int:peer_id>', methods=['DELETE'])
async def remove_from_group_dm(dm_chan, peer_id):
"""Remove users from group dm."""
user_id = await token_check()
_ctype, owner_id = await channel_check(
user_id, dm_chan, only=ChannelType.GROUP_DM
)
if owner_id != user_id:
raise Forbidden('You are now the owner of the group DM')
await _gdm_remove_recipient(dm_chan, peer_id)
return '', 204

View File

@ -76,3 +76,9 @@ async def _get_user_avatar(user_id, avatar_file):
# @bp.route('/app-icons/<int:application_id>/<icon_hash>.<ext>') # @bp.route('/app-icons/<int:application_id>/<icon_hash>.<ext>')
async def get_app_icon(application_id, icon_hash, ext): async def get_app_icon(application_id, icon_hash, ext):
pass pass
@bp.route('/channel-icons/<int:channel_id>/<icon_file>', methods=['GET'])
async def _get_gdm_icon(guild_id: int, icon_file: str):
icon_hash, ext = splitext_(icon_file)
return await send_icon('channel-icons', guild_id, icon_hash, ext=ext)

View File

@ -384,9 +384,10 @@ async def get_profile(peer_id: int):
return '', 404 return '', 404
mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id) mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id)
friends = await app.user_storage.are_friends_with(user_id, peer_id)
# don't return a proper card if no guilds are being shared. # don't return a proper card if no guilds are being shared.
if not mutuals: if not mutuals and not friends:
return '', 404 return '', 404
# actual premium status is determined by that # actual premium status is determined by that

View File

@ -103,6 +103,15 @@ class EventDispatcher:
for key in keys: for key in keys:
await self.subscribe(backend_str, key, identifier) await self.subscribe(backend_str, key, identifier)
async def mass_sub(self, identifier: Any,
backends: List[tuple]):
"""Mass subscribe to many backends at once."""
for backend_str, keys in backends:
log.debug('subscribing {} to {} keys in backend {}',
identifier, len(keys), backend_str)
await self.sub_many(backend_str, identifier, keys)
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs): async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
"""Dispatch an event to the backend. """Dispatch an event to the backend.

View File

@ -77,14 +77,14 @@ class LitecordError(Exception):
@property @property
def message(self) -> str: def message(self) -> str:
"""Get an error's message string.""" """Get an error's message string."""
err_code = getattr(self, 'error_code', None)
if err_code is not None:
return ERR_MSG_MAP.get(err_code) or self.args[0]
try: try:
return self.args[0] return self.args[0]
except IndexError: except IndexError:
err_code = getattr(self, 'error_code', None)
if err_code is not None:
return ERR_MSG_MAP.get(err_code) or self.args[0]
return repr(self) return repr(self)
@property @property

View File

@ -372,11 +372,16 @@ class GatewayWebsocket:
# user, fetch info # user, fetch info
uready = await self._user_ready() uready = await self._user_ready()
private_channels = (
await self.user_storage.get_dms(user_id) +
await self.user_storage.get_gdms(user_id)
)
await self.dispatch('READY', {**{ await self.dispatch('READY', {**{
'v': 6, 'v': 6,
'user': user, 'user': user,
'private_channels': await self.user_storage.get_dms(user_id), 'private_channels': private_channels,
'guilds': guilds, 'guilds': guilds,
'session_id': self.state.session_id, 'session_id': self.state.session_id,
@ -437,17 +442,24 @@ class GatewayWebsocket:
by GuildDispatcher.sub by GuildDispatcher.sub
""" """
user_id = self.state.user_id user_id = self.state.user_id
guild_ids = await self._guild_ids() guild_ids = await self._guild_ids()
log.info('subscribing to {} guilds', len(guild_ids))
await self.ext.dispatcher.sub_many('guild', user_id, guild_ids)
# subscribe the user to all dms they have OPENED. # subscribe the user to all dms they have OPENED.
dms = await self.user_storage.get_dms(user_id) dms = await self.user_storage.get_dms(user_id)
dm_ids = [int(dm['id']) for dm in dms] dm_ids = [int(dm['id']) for dm in dms]
# fetch all group dms the user is a member of.
gdm_ids = await self.user_storage.get_gdms_internal(user_id)
log.info('subscribing to {} guilds', len(guild_ids))
log.info('subscribing to {} dms', len(dm_ids)) log.info('subscribing to {} dms', len(dm_ids))
await self.ext.dispatcher.sub_many('channel', user_id, dm_ids) log.info('subscribing to {} group dms', len(gdm_ids))
await self.ext.dispatcher.mass_sub(user_id, [
('guild', guild_ids),
('channel', dm_ids),
('channel', gdm_ids)
])
if not self.state.bot: if not self.state.bot:
# subscribe to all friends # subscribe to all friends

View File

@ -161,21 +161,18 @@ def parse_data_uri(string) -> tuple:
def _gen_update_sql(scope: str) -> str: def _gen_update_sql(scope: str) -> str:
field = { field = {
'user': 'avatar', 'user': 'avatar',
'guild': 'icon' 'guild': 'icon',
'channel-icons': 'icon',
}[scope] }[scope]
table = { table = {
'user': 'users', 'user': 'users',
'guild': 'guilds' 'guild': 'guilds',
}[scope] 'channel-icons': 'group_dm_channels'
col = {
'user': 'id',
'guild': 'id'
}[scope] }[scope]
return f""" return f"""
SELECT {field} FROM {table} WHERE {col} = $1 SELECT {field} FROM {table} WHERE id = $1
""" """
@ -277,6 +274,9 @@ class IconManager:
async def generic_get(self, scope, key, icon_hash, **kwargs) -> Icon: async def generic_get(self, scope, key, icon_hash, **kwargs) -> Icon:
"""Get any icon.""" """Get any icon."""
if icon_hash is None:
return None
log.debug('GET {} {} {}', scope, key, icon_hash) log.debug('GET {} {} {}', scope, key, icon_hash)
key = str(key) key = str(key)
@ -409,6 +409,12 @@ class IconManager:
WHERE icon = $1 WHERE icon = $1
""", icon.icon_hash) """, icon.icon_hash)
await self.storage.db.execute("""
UPDATE group_dm_channels
SET icon = NULL
WHERE icon = $1
""", icon.icon_hash)
await self.storage.db.execute(""" await self.storage.db.execute("""
DELETE FROM icons DELETE FROM icons
WHERE hash = $1 WHERE hash = $1

View File

@ -22,10 +22,32 @@ from typing import Any
from logbook import Logger from logbook import Logger
from .dispatcher import DispatcherWithState from .dispatcher import DispatcherWithState
from litecord.enums import ChannelType
from litecord.utils import index_by_func
log = Logger(__name__) log = Logger(__name__)
def gdm_recipient_view(orig: dict, user_id: int) -> dict:
"""Create a copy of the original channel object that doesn't
show the user we are dispatching it to.
this only applies to group dms and discords' api design that says
a group dms' recipients must not show the original user.
"""
# make a copy or the original channel object
data = dict(orig)
idx = index_by_func(
lambda user: user['id'] == str(user_id),
data['recipients']
)
data['recipients'].pop(idx)
return data
class ChannelDispatcher(DispatcherWithState): class ChannelDispatcher(DispatcherWithState):
"""Main channel Pub/Sub logic.""" """Main channel Pub/Sub logic."""
KEY_TYPE = int KEY_TYPE = int
@ -62,7 +84,19 @@ class ChannelDispatcher(DispatcherWithState):
await self.unsub(channel_id, user_id) await self.unsub(channel_id, user_id)
continue continue
cur_sess = await self._dispatch_states(states, event, data) cur_sess = 0
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
and data.get('type') == ChannelType.GROUP_DM.value:
# we edit the channel payload so it doesn't show
# the user as a recipient
new_data = gdm_recipient_view(data, user_id)
cur_sess = await self._dispatch_states(
states, event, new_data)
else:
cur_sess = await self._dispatch_states(
states, event, data)
sessions.extend(cur_sess) sessions.extend(cur_sess)
dispatched += len(cur_sess) dispatched += len(cur_sess)

View File

@ -581,6 +581,14 @@ CREATE_GROUP_DM = {
}, },
} }
GROUP_DM_UPDATE = {
'name': {
'type': 'guild_name',
'required': False
},
'icon': {'type': 'b64_icon', 'required': False, 'nullable': True},
}
SPECIFIC_FRIEND = { SPECIFIC_FRIEND = {
'username': {'type': 'username'}, 'username': {'type': 'username'},
'discriminator': {'type': 'discriminator'} 'discriminator': {'type': 'discriminator'}

View File

@ -353,7 +353,38 @@ class Storage:
return list(map(_overwrite_convert, overwrite_rows)) return list(map(_overwrite_convert, overwrite_rows))
async def get_channel(self, channel_id: int) -> Dict[str, Any]: async def gdm_recipient_ids(self, channel_id: int) -> List[int]:
"""Get the list of user IDs that are recipients of the
given Group DM."""
user_ids = await self.db.fetch("""
SELECT member_id
FROM group_dm_members
JOIN users
ON member_id = users.id
WHERE group_dm_members.id = $1
ORDER BY username DESC
""", channel_id)
return [r['member_id'] for r in user_ids]
async def _gdm_recipients(self, channel_id: int,
reference_id: int = None) -> List[int]:
"""Get the list of users that are recipients of the
given Group DM."""
recipients = await self.gdm_recipient_ids(channel_id)
res = []
for user_id in recipients:
if user_id == reference_id:
continue
res.append(
await self.get_user(user_id)
)
return res
async def get_channel(self, channel_id: int, **kwargs) -> Dict[str, Any]:
"""Fetch a single channel's information.""" """Fetch a single channel's information."""
chan_type = await self.get_chan_type(channel_id) chan_type = await self.get_chan_type(channel_id)
ctype = ChannelType(chan_type) ctype = ChannelType(chan_type)
@ -387,7 +418,8 @@ class Storage:
drow['type'] = chan_type drow['type'] = chan_type
drow['last_message_id'] = await self.chan_last_message_str( drow['last_message_id'] = await self.chan_last_message_str(
channel_id) channel_id
)
# dms have just two recipients. # dms have just two recipients.
drow['recipients'] = [ drow['recipients'] = [
@ -401,8 +433,22 @@ class Storage:
drow['id'] = str(drow['id']) drow['id'] = str(drow['id'])
return drow return drow
elif ctype == ChannelType.GROUP_DM: elif ctype == ChannelType.GROUP_DM:
# TODO: group dms gdm_row = await self.db.fetchrow("""
pass SELECT id::text, owner_id::text, name, icon
FROM group_dm_channels
WHERE id = $1
""", channel_id)
drow = dict(gdm_row)
drow['type'] = chan_type
drow['recipients'] = await self._gdm_recipients(
channel_id, kwargs.get('user_id')
)
drow['last_message_id'] = await self.chan_last_message_str(
channel_id
)
return drow
return None return None

View File

@ -17,11 +17,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from logbook import Logger
from litecord.snowflake import get_snowflake from litecord.snowflake import get_snowflake
from litecord.enums import MessageType from litecord.enums import MessageType
log = Logger(__name__)
async def _handle_pin_msg(app, channel_id, pinned_id, author_id): async def _handle_pin_msg(app, channel_id, _pinned_id, author_id):
"""Handle a message pin.""" """Handle a message pin."""
new_id = get_snowflake() new_id = get_snowflake()
@ -41,12 +44,108 @@ async def _handle_pin_msg(app, channel_id, pinned_id, author_id):
return new_id return new_id
# TODO: decrease repetition between add and remove handlers
async def _handle_recp_add(app, channel_id, author_id, peer_id):
new_id = get_snowflake()
await app.db.execute(
"""
INSERT INTO messages
(id, channel_id, author_id, webhook_id,
content, message_type)
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
f'<@{peer_id}>',
MessageType.RECIPIENT_ADD.value
)
return new_id
async def _handle_recp_rmv(app, channel_id, author_id, peer_id):
new_id = get_snowflake()
await app.db.execute(
"""
INSERT INTO messages
(id, channel_id, author_id, webhook_id,
content, message_type)
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
f'<@{peer_id}>',
MessageType.RECIPIENT_REMOVE.value
)
return new_id
async def _handle_gdm_name_edit(app, channel_id, author_id):
new_id = get_snowflake()
gdm_name = await app.db.fetchval("""
SELECT name FROM group_dm_channels
WHERE id = $1
""", channel_id)
if not gdm_name:
log.warning('no gdm name found for sys message')
return
await app.db.execute(
"""
INSERT INTO messages
(id, channel_id, author_id, webhook_id,
content, message_type)
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
gdm_name,
MessageType.CHANNEL_NAME_CHANGE.value
)
return new_id
async def _handle_gdm_icon_edit(app, channel_id, author_id):
new_id = get_snowflake()
await app.db.execute(
"""
INSERT INTO messages
(id, channel_id, author_id, webhook_id,
content, message_type)
VALUES
($1, $2, $3, NULL, $4, $5)
""",
new_id, channel_id, author_id,
'',
MessageType.CHANNEL_ICON_CHANGE.value
)
return new_id
async def send_sys_message(app, channel_id: int, m_type: MessageType, async def send_sys_message(app, channel_id: int, m_type: MessageType,
*args, **kwargs) -> int: *args, **kwargs) -> int:
"""Send a system message.""" """Send a system message."""
handler = { try:
MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg, handler = {
}[m_type] MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg,
# gdm specific
MessageType.RECIPIENT_ADD: _handle_recp_add,
MessageType.RECIPIENT_REMOVE: _handle_recp_rmv,
MessageType.CHANNEL_NAME_CHANGE: _handle_gdm_name_edit,
MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit
}[m_type]
except KeyError:
raise ValueError('Invalid system message type')
message_id = await handler(app, channel_id, *args, **kwargs) message_id = await handler(app, channel_id, *args, **kwargs)

View File

@ -338,3 +338,26 @@ class UserStorage:
) )
) )
""", user_id, peer_id) """, user_id, peer_id)
async def get_gdms_internal(self, user_id) -> List[int]:
"""Return a list of Group DM IDs the user is a member of."""
rows = await self.db.fetch("""
SELECT id
FROM group_dm_members
WHERE member_id = $1
""", user_id)
return [r['id'] for r in rows]
async def get_gdms(self, user_id) -> List[Dict[str, Any]]:
"""Get list of group DMs a user is in."""
gdm_ids = await self.get_gdms_internal(user_id)
res = []
for gdm_id in gdm_ids:
res.append(
await self.storage.get_channel(gdm_id, user_id=user_id)
)
return res

3
run.py
View File

@ -35,7 +35,7 @@ import config
from litecord.blueprints import ( from litecord.blueprints import (
gateway, auth, users, guilds, channels, webhooks, science, gateway, auth, users, guilds, channels, webhooks, science,
voice, invites, relationships, dms, icons, nodeinfo, static, voice, invites, relationships, dms, icons, nodeinfo, static,
attachments attachments, dm_channels
) )
# those blueprints are separated from the "main" ones # those blueprints are separated from the "main" ones
@ -128,6 +128,7 @@ def set_blueprints(app_):
voice: '/voice', voice: '/voice',
invites: None, invites: None,
dms: '/users', dms: '/users',
dm_channels: '/channels',
fake_store: None, fake_store: None,