diff --git a/litecord/blueprints/__init__.py b/litecord/blueprints/__init__.py index a22b5b1..9b021b4 100644 --- a/litecord/blueprints/__init__.py +++ b/litecord/blueprints/__init__.py @@ -32,7 +32,9 @@ from .icons import bp as icons from .nodeinfo import bp as nodeinfo from .static import bp as static from .attachments import bp as attachments +from .dm_channels import bp as dm_channels __all__ = ['gateway', 'auth', 'users', 'guilds', 'channels', 'webhooks', 'science', 'voice', 'invites', 'relationships', - 'dms', 'icons', 'nodeinfo', 'static', 'attachments'] + 'dms', 'icons', 'nodeinfo', 'static', 'attachments', + 'dm_channels'] diff --git a/litecord/blueprints/channels.py b/litecord/blueprints/channels.py index 585d6a6..76324dd 100644 --- a/litecord/blueprints/channels.py +++ b/litecord/blueprints/channels.py @@ -23,13 +23,14 @@ from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger from litecord.auth import token_check -from litecord.enums import ChannelType, GUILD_CHANS +from litecord.enums import ChannelType, GUILD_CHANS, MessageType from litecord.errors import ChannelNotFound from litecord.schemas import ( - validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL + validate, CHAN_UPDATE, CHAN_OVERWRITE, SEARCH_CHANNEL, GROUP_DM_UPDATE ) from litecord.blueprints.checks import channel_check, channel_perm_check +from litecord.system_messages import send_sys_message log = Logger(__name__) bp = Blueprint('channels', __name__) @@ -405,31 +406,73 @@ async def _update_voice_channel(channel_id: int, j: dict): await _common_guild_chan(channel_id, j) +async def _update_group_dm(channel_id: int, j: dict, author_id: int): + if 'name' in j: + await app.db.execute(""" + UPDATE group_dm_channels + SET name = $1 + WHERE id = $2 + """, j['name'], channel_id) + + await send_sys_message( + app, channel_id, MessageType.CHANNEL_NAME_CHANGE, author_id + ) + + if 'icon' in j: + new_icon = await app.icons.update( + 'channel-icons', channel_id, j['icon'], always_icon=True + ) + + await app.db.execute(""" + UPDATE group_dm_channels + SET icon = $1 + WHERE id = $2 + """, new_icon.icon_hash, channel_id) + + await send_sys_message( + app, channel_id, MessageType.CHANNEL_ICON_CHANGE, author_id + ) + + @bp.route('/', methods=['PUT', 'PATCH']) async def update_channel(channel_id): """Update a channel's information""" user_id = await token_check() ctype, guild_id = await channel_check(user_id, channel_id) - if ctype not in GUILD_CHANS: - raise ChannelNotFound('Can not edit non-guild channels.') + if ctype not in (ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE, + ChannelType.GROUP_DM): + raise ChannelNotFound('unable to edit unsupported chan type') - await channel_perm_check(user_id, channel_id, 'manage_channels') - j = validate(await request.get_json(), CHAN_UPDATE) + is_guild = ctype in GUILD_CHANS - # TODO: categories? + if is_guild: + await channel_perm_check(user_id, channel_id, 'manage_channels') + + j = validate(await request.get_json(), + CHAN_UPDATE if is_guild else GROUP_DM_UPDATE) + + # TODO: categories update_handler = { ChannelType.GUILD_TEXT: _update_text_channel, ChannelType.GUILD_VOICE: _update_voice_channel, + ChannelType.GROUP_DM: _update_group_dm, }[ctype] - await _update_channel_common(channel_id, guild_id, j) - await update_handler(channel_id, j) + if is_guild: + await _update_channel_common(channel_id, guild_id, j) + + await update_handler(channel_id, j, user_id) chan = await app.storage.get_channel(channel_id) + if is_guild: + await app.dispatcher.dispatch( + 'guild', guild_id, 'CHANNEL_UPDATE', chan) + else: + await app.dispatcher.dispatch( + 'channel', channel_id, 'CHANNEL_UPDATE', chan) - await app.dispatcher.dispatch('guild', guild_id, 'CHANNEL_UPDATE', chan) return jsonify(chan) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 9c9253a..e897c21 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -17,6 +17,8 @@ along with this program. If not, see . """ +from typing import Union, List + from quart import current_app as app from litecord.enums import ChannelType, GUILD_CHANS @@ -53,16 +55,23 @@ async def guild_owner_check(user_id: int, guild_id: int): raise Forbidden('You are not the owner of the guild') -async def channel_check(user_id, channel_id): +async def channel_check(user_id, channel_id, *, + only: Union[ChannelType, List[ChannelType]] = None): """Check if the current user is authorized to read the channel's information.""" chan_type = await app.storage.get_chan_type(channel_id) if chan_type is None: - raise ChannelNotFound(f'channel type not found') + raise ChannelNotFound('channel type not found') ctype = ChannelType(chan_type) + if only and not isinstance(only, list): + only = [only] + + if only and ctype not in only: + raise ChannelNotFound('invalid channel type') + if ctype in GUILD_CHANS: guild_id = await app.db.fetchval(""" SELECT guild_id @@ -77,6 +86,15 @@ async def channel_check(user_id, channel_id): peer_id = await app.storage.get_dm_peer(channel_id, user_id) return ctype, peer_id + if ctype == ChannelType.GROUP_DM: + owner_id = await app.db.fetchval(""" + SELECT owner_id + FROM group_dm_channels + WHERE id = $1 + """, channel_id) + + return ctype, owner_id + async def guild_perm_check(user_id, guild_id, permission: str): """Check guild permissions for a user.""" diff --git a/litecord/blueprints/dm_channels.py b/litecord/blueprints/dm_channels.py new file mode 100644 index 0000000..a79ed8b --- /dev/null +++ b/litecord/blueprints/dm_channels.py @@ -0,0 +1,182 @@ +""" + +Litecord +Copyright (C) 2018-2019 Luna Mendes + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, version 3 of the License. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . + +""" + +from quart import Blueprint, current_app as app, jsonify +from logbook import Logger + +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import channel_check +from litecord.enums import ChannelType, MessageType +from litecord.errors import BadRequest, Forbidden +from litecord.snowflake import get_snowflake +from litecord.system_messages import send_sys_message +from litecord.pubsub.channel import gdm_recipient_view + +log = Logger(__name__) +bp = Blueprint('dm_channels', __name__) + + +async def _raw_gdm_add(channel_id, user_id): + await app.db.execute(""" + INSERT INTO group_dm_members (id, member_id) + VALUES ($1, $2) + """, channel_id, user_id) + + +async def _raw_gdm_remove(channel_id, user_id): + await app.db.execute(""" + DELETE FROM group_dm_members + WHERE id = $1 AND member_id = $2 + """, channel_id, user_id) + + +async def _gdm_create(user_id, peer_id) -> int: + """Create a group dm, given two users. + + Returns the new GDM id. + """ + channel_id = get_snowflake() + + await app.db.execute(""" + INSERT INTO channels (id, channel_type) + VALUES ($1, $2) + """, channel_id, ChannelType.GROUP_DM.value) + + await app.db.execute(""" + INSERT INTO group_dm_channels (id, owner_id, name, icon) + VALUES ($1, $2, NULL, NULL) + """, channel_id, user_id) + + await _raw_gdm_add(channel_id, user_id) + await _raw_gdm_add(channel_id, peer_id) + + await app.dispatcher.sub('channel', channel_id, user_id) + await app.dispatcher.sub('channel', channel_id, peer_id) + + chan = await app.storage.get_channel(channel_id) + await app.dispatcher.dispatch('channel', channel_id, 'CHANNEL_CREATE', chan) + + return channel_id + + +async def _gdm_add_recipient(channel_id: int, peer_id: int, *, user_id=None): + """Add a recipient to a Group DM. + + Dispatches: + - A system message with the join (depending of user_id) + - A CHANNEL_CREATE to the peer. + - A CHANNEL_UPDATE to all remaining recipients. + """ + await _raw_gdm_add(channel_id, peer_id) + + chan = await app.storage.get_channel(channel_id) + + # the reasoning behind gdm_recipient_view is in its docstring. + await app.dispatcher.dispatch( + 'user', peer_id, 'CHANNEL_CREATE', gdm_recipient_view(chan, peer_id)) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'CHANNEL_UPDATE', chan) + + await app.dispatcher.sub('channel', peer_id) + + if user_id: + await send_sys_message( + app, channel_id, MessageType.RECIPIENT_ADD, + user_id, peer_id + ) + + +async def _gdm_remove_recipient(channel_id: int, peer_id: int, *, user_id=None): + """Remove a member from a GDM. + + Dispatches: + - A system message with the leave or forced removal (depending if user_id) + exists or not. + - A CHANNEL_DELETE to the peer. + - A CHANNEL_UPDATE to all remaining recipients. + """ + await _raw_gdm_remove(channel_id, peer_id) + + chan = await app.storage.get_channel(channel_id) + await app.dispatcher.dispatch( + 'user', peer_id, 'CHANNEL_DELETE', gdm_recipient_view(chan, user_id)) + + await app.dispatcher.unsub('channel', peer_id) + + await app.dispatcher.dispatch( + 'channel', channel_id, 'CHANNEL_RECIPIENT_REMOVE', { + 'channel_id': str(channel_id), + 'user': await app.storage.get_user(peer_id) + } + ) + + if user_id: + await send_sys_message( + app, channel_id, MessageType.RECIPIENT_REMOVE, + user_id, peer_id + ) + + +@bp.route('//recipients/', methods=['PUT']) +async def add_to_group_dm(dm_chan, peer_id): + """Adds a member to a group dm OR creates a group dm.""" + user_id = await token_check() + + # other_id is the owner of the group dm (gdm) if the + # given channel is a gdm + + # other_id is the peer of the dm if the given channel is a dm + ctype, other_id = await channel_check( + user_id, dm_chan, + only=[ChannelType.DM, ChannelType.GROUP_DM] + ) + + # check relationship with the given user id + # and the user id making the request + friends = await app.user_storage.are_friends_with(user_id, peer_id) + + if not friends: + raise BadRequest('Cant insert peer into dm') + + if ctype == ChannelType.DM: + dm_chan = await _gdm_create( + user_id, other_id + ) + + await _gdm_add_recipient(dm_chan, peer_id, user_id=user_id) + + return jsonify( + await app.storage.get_channel(dm_chan) + ) + + +@bp.route('//recipients/', methods=['DELETE']) +async def remove_from_group_dm(dm_chan, peer_id): + """Remove users from group dm.""" + user_id = await token_check() + _ctype, owner_id = await channel_check( + user_id, dm_chan, only=ChannelType.GROUP_DM + ) + + if owner_id != user_id: + raise Forbidden('You are now the owner of the group DM') + + await _gdm_remove_recipient(dm_chan, peer_id) + return '', 204 diff --git a/litecord/blueprints/icons.py b/litecord/blueprints/icons.py index 86662b2..9daa45b 100644 --- a/litecord/blueprints/icons.py +++ b/litecord/blueprints/icons.py @@ -76,3 +76,9 @@ async def _get_user_avatar(user_id, avatar_file): # @bp.route('/app-icons//.') async def get_app_icon(application_id, icon_hash, ext): pass + + +@bp.route('/channel-icons//', methods=['GET']) +async def _get_gdm_icon(guild_id: int, icon_file: str): + icon_hash, ext = splitext_(icon_file) + return await send_icon('channel-icons', guild_id, icon_hash, ext=ext) diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 969a043..8e9deb7 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -384,9 +384,10 @@ async def get_profile(peer_id: int): return '', 404 mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id) + friends = await app.user_storage.are_friends_with(user_id, peer_id) # don't return a proper card if no guilds are being shared. - if not mutuals: + if not mutuals and not friends: return '', 404 # actual premium status is determined by that diff --git a/litecord/dispatcher.py b/litecord/dispatcher.py index 59e19a3..8d38303 100644 --- a/litecord/dispatcher.py +++ b/litecord/dispatcher.py @@ -103,6 +103,15 @@ class EventDispatcher: for key in keys: await self.subscribe(backend_str, key, identifier) + async def mass_sub(self, identifier: Any, + backends: List[tuple]): + """Mass subscribe to many backends at once.""" + for backend_str, keys in backends: + log.debug('subscribing {} to {} keys in backend {}', + identifier, len(keys), backend_str) + + await self.sub_many(backend_str, identifier, keys) + async def dispatch(self, backend_str: str, key: Any, *args, **kwargs): """Dispatch an event to the backend. diff --git a/litecord/errors.py b/litecord/errors.py index 25e7baa..1e9942a 100644 --- a/litecord/errors.py +++ b/litecord/errors.py @@ -77,14 +77,14 @@ class LitecordError(Exception): @property def message(self) -> str: """Get an error's message string.""" - err_code = getattr(self, 'error_code', None) - - if err_code is not None: - return ERR_MSG_MAP.get(err_code) or self.args[0] - try: return self.args[0] except IndexError: + err_code = getattr(self, 'error_code', None) + + if err_code is not None: + return ERR_MSG_MAP.get(err_code) or self.args[0] + return repr(self) @property diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 9dc402c..ca2d444 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -372,11 +372,16 @@ class GatewayWebsocket: # user, fetch info uready = await self._user_ready() + private_channels = ( + await self.user_storage.get_dms(user_id) + + await self.user_storage.get_gdms(user_id) + ) + await self.dispatch('READY', {**{ 'v': 6, 'user': user, - 'private_channels': await self.user_storage.get_dms(user_id), + 'private_channels': private_channels, 'guilds': guilds, 'session_id': self.state.session_id, @@ -437,17 +442,24 @@ class GatewayWebsocket: by GuildDispatcher.sub """ user_id = self.state.user_id - guild_ids = await self._guild_ids() - log.info('subscribing to {} guilds', len(guild_ids)) - await self.ext.dispatcher.sub_many('guild', user_id, guild_ids) # subscribe the user to all dms they have OPENED. dms = await self.user_storage.get_dms(user_id) dm_ids = [int(dm['id']) for dm in dms] + # fetch all group dms the user is a member of. + gdm_ids = await self.user_storage.get_gdms_internal(user_id) + + log.info('subscribing to {} guilds', len(guild_ids)) log.info('subscribing to {} dms', len(dm_ids)) - await self.ext.dispatcher.sub_many('channel', user_id, dm_ids) + log.info('subscribing to {} group dms', len(gdm_ids)) + + await self.ext.dispatcher.mass_sub(user_id, [ + ('guild', guild_ids), + ('channel', dm_ids), + ('channel', gdm_ids) + ]) if not self.state.bot: # subscribe to all friends diff --git a/litecord/images.py b/litecord/images.py index bf3434c..854e2b9 100644 --- a/litecord/images.py +++ b/litecord/images.py @@ -161,21 +161,18 @@ def parse_data_uri(string) -> tuple: def _gen_update_sql(scope: str) -> str: field = { 'user': 'avatar', - 'guild': 'icon' + 'guild': 'icon', + 'channel-icons': 'icon', }[scope] table = { 'user': 'users', - 'guild': 'guilds' - }[scope] - - col = { - 'user': 'id', - 'guild': 'id' + 'guild': 'guilds', + 'channel-icons': 'group_dm_channels' }[scope] return f""" - SELECT {field} FROM {table} WHERE {col} = $1 + SELECT {field} FROM {table} WHERE id = $1 """ @@ -277,6 +274,9 @@ class IconManager: async def generic_get(self, scope, key, icon_hash, **kwargs) -> Icon: """Get any icon.""" + if icon_hash is None: + return None + log.debug('GET {} {} {}', scope, key, icon_hash) key = str(key) @@ -409,6 +409,12 @@ class IconManager: WHERE icon = $1 """, icon.icon_hash) + await self.storage.db.execute(""" + UPDATE group_dm_channels + SET icon = NULL + WHERE icon = $1 + """, icon.icon_hash) + await self.storage.db.execute(""" DELETE FROM icons WHERE hash = $1 diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py index ef23a6a..2b32a1d 100644 --- a/litecord/pubsub/channel.py +++ b/litecord/pubsub/channel.py @@ -22,10 +22,32 @@ from typing import Any from logbook import Logger from .dispatcher import DispatcherWithState +from litecord.enums import ChannelType +from litecord.utils import index_by_func log = Logger(__name__) +def gdm_recipient_view(orig: dict, user_id: int) -> dict: + """Create a copy of the original channel object that doesn't + show the user we are dispatching it to. + + this only applies to group dms and discords' api design that says + a group dms' recipients must not show the original user. + """ + # make a copy or the original channel object + data = dict(orig) + + idx = index_by_func( + lambda user: user['id'] == str(user_id), + data['recipients'] + ) + + data['recipients'].pop(idx) + + return data + + class ChannelDispatcher(DispatcherWithState): """Main channel Pub/Sub logic.""" KEY_TYPE = int @@ -62,7 +84,19 @@ class ChannelDispatcher(DispatcherWithState): await self.unsub(channel_id, user_id) continue - cur_sess = await self._dispatch_states(states, event, data) + cur_sess = 0 + + if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \ + and data.get('type') == ChannelType.GROUP_DM.value: + # we edit the channel payload so it doesn't show + # the user as a recipient + + new_data = gdm_recipient_view(data, user_id) + cur_sess = await self._dispatch_states( + states, event, new_data) + else: + cur_sess = await self._dispatch_states( + states, event, data) sessions.extend(cur_sess) dispatched += len(cur_sess) diff --git a/litecord/schemas.py b/litecord/schemas.py index e336eba..c39c46d 100644 --- a/litecord/schemas.py +++ b/litecord/schemas.py @@ -581,6 +581,14 @@ CREATE_GROUP_DM = { }, } +GROUP_DM_UPDATE = { + 'name': { + 'type': 'guild_name', + 'required': False + }, + 'icon': {'type': 'b64_icon', 'required': False, 'nullable': True}, +} + SPECIFIC_FRIEND = { 'username': {'type': 'username'}, 'discriminator': {'type': 'discriminator'} diff --git a/litecord/storage.py b/litecord/storage.py index 3f96058..4554589 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -353,7 +353,38 @@ class Storage: return list(map(_overwrite_convert, overwrite_rows)) - async def get_channel(self, channel_id: int) -> Dict[str, Any]: + async def gdm_recipient_ids(self, channel_id: int) -> List[int]: + """Get the list of user IDs that are recipients of the + given Group DM.""" + user_ids = await self.db.fetch(""" + SELECT member_id + FROM group_dm_members + JOIN users + ON member_id = users.id + WHERE group_dm_members.id = $1 + ORDER BY username DESC + """, channel_id) + + return [r['member_id'] for r in user_ids] + + async def _gdm_recipients(self, channel_id: int, + reference_id: int = None) -> List[int]: + """Get the list of users that are recipients of the + given Group DM.""" + recipients = await self.gdm_recipient_ids(channel_id) + res = [] + + for user_id in recipients: + if user_id == reference_id: + continue + + res.append( + await self.get_user(user_id) + ) + + return res + + async def get_channel(self, channel_id: int, **kwargs) -> Dict[str, Any]: """Fetch a single channel's information.""" chan_type = await self.get_chan_type(channel_id) ctype = ChannelType(chan_type) @@ -387,7 +418,8 @@ class Storage: drow['type'] = chan_type drow['last_message_id'] = await self.chan_last_message_str( - channel_id) + channel_id + ) # dms have just two recipients. drow['recipients'] = [ @@ -401,8 +433,22 @@ class Storage: drow['id'] = str(drow['id']) return drow elif ctype == ChannelType.GROUP_DM: - # TODO: group dms - pass + gdm_row = await self.db.fetchrow(""" + SELECT id::text, owner_id::text, name, icon + FROM group_dm_channels + WHERE id = $1 + """, channel_id) + + drow = dict(gdm_row) + drow['type'] = chan_type + drow['recipients'] = await self._gdm_recipients( + channel_id, kwargs.get('user_id') + ) + drow['last_message_id'] = await self.chan_last_message_str( + channel_id + ) + + return drow return None diff --git a/litecord/system_messages.py b/litecord/system_messages.py index 33a10a0..0f13b8e 100644 --- a/litecord/system_messages.py +++ b/litecord/system_messages.py @@ -17,11 +17,14 @@ along with this program. If not, see . """ +from logbook import Logger + from litecord.snowflake import get_snowflake from litecord.enums import MessageType +log = Logger(__name__) -async def _handle_pin_msg(app, channel_id, pinned_id, author_id): +async def _handle_pin_msg(app, channel_id, _pinned_id, author_id): """Handle a message pin.""" new_id = get_snowflake() @@ -41,12 +44,108 @@ async def _handle_pin_msg(app, channel_id, pinned_id, author_id): return new_id +# TODO: decrease repetition between add and remove handlers +async def _handle_recp_add(app, channel_id, author_id, peer_id): + new_id = get_snowflake() + + await app.db.execute( + """ + INSERT INTO messages + (id, channel_id, author_id, webhook_id, + content, message_type) + VALUES + ($1, $2, $3, NULL, $4, $5) + """, + new_id, channel_id, author_id, + f'<@{peer_id}>', + MessageType.RECIPIENT_ADD.value + ) + + return new_id + + + +async def _handle_recp_rmv(app, channel_id, author_id, peer_id): + new_id = get_snowflake() + + await app.db.execute( + """ + INSERT INTO messages + (id, channel_id, author_id, webhook_id, + content, message_type) + VALUES + ($1, $2, $3, NULL, $4, $5) + """, + new_id, channel_id, author_id, + f'<@{peer_id}>', + MessageType.RECIPIENT_REMOVE.value + ) + + return new_id + + +async def _handle_gdm_name_edit(app, channel_id, author_id): + new_id = get_snowflake() + + gdm_name = await app.db.fetchval(""" + SELECT name FROM group_dm_channels + WHERE id = $1 + """, channel_id) + + if not gdm_name: + log.warning('no gdm name found for sys message') + return + + await app.db.execute( + """ + INSERT INTO messages + (id, channel_id, author_id, webhook_id, + content, message_type) + VALUES + ($1, $2, $3, NULL, $4, $5) + """, + new_id, channel_id, author_id, + gdm_name, + MessageType.CHANNEL_NAME_CHANGE.value + ) + + return new_id + + +async def _handle_gdm_icon_edit(app, channel_id, author_id): + new_id = get_snowflake() + + await app.db.execute( + """ + INSERT INTO messages + (id, channel_id, author_id, webhook_id, + content, message_type) + VALUES + ($1, $2, $3, NULL, $4, $5) + """, + new_id, channel_id, author_id, + '', + MessageType.CHANNEL_ICON_CHANGE.value + ) + + return new_id + + async def send_sys_message(app, channel_id: int, m_type: MessageType, *args, **kwargs) -> int: """Send a system message.""" - handler = { - MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg, - }[m_type] + try: + handler = { + MessageType.CHANNEL_PINNED_MESSAGE: _handle_pin_msg, + + # gdm specific + MessageType.RECIPIENT_ADD: _handle_recp_add, + MessageType.RECIPIENT_REMOVE: _handle_recp_rmv, + MessageType.CHANNEL_NAME_CHANGE: _handle_gdm_name_edit, + MessageType.CHANNEL_ICON_CHANGE: _handle_gdm_icon_edit + }[m_type] + except KeyError: + raise ValueError('Invalid system message type') message_id = await handler(app, channel_id, *args, **kwargs) diff --git a/litecord/user_storage.py b/litecord/user_storage.py index 972e16a..baaeb2a 100644 --- a/litecord/user_storage.py +++ b/litecord/user_storage.py @@ -338,3 +338,26 @@ class UserStorage: ) ) """, user_id, peer_id) + + async def get_gdms_internal(self, user_id) -> List[int]: + """Return a list of Group DM IDs the user is a member of.""" + rows = await self.db.fetch(""" + SELECT id + FROM group_dm_members + WHERE member_id = $1 + """, user_id) + + return [r['id'] for r in rows] + + async def get_gdms(self, user_id) -> List[Dict[str, Any]]: + """Get list of group DMs a user is in.""" + gdm_ids = await self.get_gdms_internal(user_id) + + res = [] + + for gdm_id in gdm_ids: + res.append( + await self.storage.get_channel(gdm_id, user_id=user_id) + ) + + return res diff --git a/run.py b/run.py index f9ffde2..ddebbe8 100644 --- a/run.py +++ b/run.py @@ -35,7 +35,7 @@ import config from litecord.blueprints import ( gateway, auth, users, guilds, channels, webhooks, science, voice, invites, relationships, dms, icons, nodeinfo, static, - attachments + attachments, dm_channels ) # those blueprints are separated from the "main" ones @@ -128,6 +128,7 @@ def set_blueprints(app_): voice: '/voice', invites: None, dms: '/users', + dm_channels: '/channels', fake_store: None,