From 72cbd8017bb4ce7ccac515adcf5957e91f01e78b Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 8 Feb 2019 19:00:38 -0300 Subject: [PATCH] dm_channels: add basic checks and dummy calls to _gdm prefixed funcs - checks: return group dm owner when channel_check is given a group dm - users: check for friendship on profile --- litecord/blueprints/checks.py | 8 +++++- litecord/blueprints/dm_channels.py | 40 ++++++++++++++++++++++++++---- litecord/blueprints/users.py | 3 ++- litecord/storage.py | 40 ++++++++++++++++++++++++++++-- 4 files changed, 82 insertions(+), 9 deletions(-) diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index a27fcb4..5a6846c 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -87,7 +87,13 @@ async def channel_check(user_id, channel_id, *, return ctype, peer_id if ctype == ChannelType.GROUP_DM: - return ctype + owner_id = await app.db.fetchval(""" + SELECT owner_id + FROM group_dm_channels + WHERE id = $1 + """, channel_id) + + return ctype, owner_id async def guild_perm_check(user_id, guild_id, permission: str): diff --git a/litecord/blueprints/dm_channels.py b/litecord/blueprints/dm_channels.py index 310150b..3b395a1 100644 --- a/litecord/blueprints/dm_channels.py +++ b/litecord/blueprints/dm_channels.py @@ -17,31 +17,61 @@ along with this program. If not, see . """ -from quart import Blueprint, request, current_app as app, jsonify +from quart import Blueprint, current_app as app, jsonify from logbook import Logger from litecord.blueprints.auth import token_check from litecord.blueprints.checks import channel_check from litecord.enums import ChannelType +from litecord.errors import BadRequest, Forbidden log = Logger(__name__) bp = Blueprint('dm_channels', __name__) -@bp.route('//receipients/', methods=['PUT']) -async def add_to_group_dm(dm_chan, user_id): +@bp.route('//receipients/', methods=['PUT']) +async def add_to_group_dm(dm_chan, peer_id): """Adds a member to a group dm OR creates a group dm.""" user_id = await token_check() - ctype = await channel_check( + + # other_id is the owner of the group dm (gdm) if the + # given channel is a gdm + + # other_id is the peer of the dm if the given channel is a dm + ctype, other_id = await channel_check( user_id, dm_chan, only=[ChannelType.DM, ChannelType.GROUP_DM] ) + # check relationship with the given user id + # and the user id making the request + friends = await app.user_storage.are_friends_with(user_id, peer_id) + + if not friends: + raise BadRequest('Cant insert peer into dm') + + if ctype == ChannelType.DM: + dm_chan = await _gdm_create( + user_id, other_id + ) + + await _gdm_add_recipient(dm_chan, peer_id, user_id=user_id) + + return jsonify( + await app.storage.get_channel(dm_chan) + ) + @bp.route('//recipients/', methods=['DELETE']) async def remove_from_group_dm(dm_chan, user_id): """Remove users from group dm.""" user_id = await token_check() - ctype = await channel_check( + _ctype, owner_id = await channel_check( user_id, dm_chan, only=ChannelType.GROUP_DM ) + + if owner_id != user_id: + raise Forbidden('You are now the owner of the group DM') + + await _gdm_remove_recipient(dm_chan, user_id) + return '', 204 diff --git a/litecord/blueprints/users.py b/litecord/blueprints/users.py index 969a043..8e9deb7 100644 --- a/litecord/blueprints/users.py +++ b/litecord/blueprints/users.py @@ -384,9 +384,10 @@ async def get_profile(peer_id: int): return '', 404 mutuals = await app.user_storage.get_mutual_guilds(user_id, peer_id) + friends = await app.user_storage.are_friends_with(user_id, peer_id) # don't return a proper card if no guilds are being shared. - if not mutuals: + if not mutuals and not friends: return '', 404 # actual premium status is determined by that diff --git a/litecord/storage.py b/litecord/storage.py index 3f96058..d44e035 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -353,6 +353,33 @@ class Storage: return list(map(_overwrite_convert, overwrite_rows)) + async def _gdm_recipient_ids(self, channel_id: int) -> List[int]: + """Get the list of user IDs that are recipients of the + given Group DM.""" + user_ids = await self.db.fetch(""" + SELECT member_id + FROM group_dm_members + JOIN users + ON member_id = users.id + WHERE group_dm_members.id = $1 + ORDER BY username DESC + """, channel_id) + + return [r['member_id'] for r in user_ids] + + async def _gdm_recipients(self, channel_id: int) -> List[int]: + """Get the list of users that are recipients of the + given Group DM.""" + recipients = await self._gdm_recipient_ids(channel_id) + res = [] + + for user_id in recipients: + res.append( + await self.get_user(user_id) + ) + + return res + async def get_channel(self, channel_id: int) -> Dict[str, Any]: """Fetch a single channel's information.""" chan_type = await self.get_chan_type(channel_id) @@ -401,8 +428,17 @@ class Storage: drow['id'] = str(drow['id']) return drow elif ctype == ChannelType.GROUP_DM: - # TODO: group dms - pass + gdm_row = await self.db.fetchrow(""" + SELECT id, owner_id::text, name, icon + FROM group_dm_channels + WHERE id = $1 + """, channel_id) + + drow = dict(gdm_row) + recipients + drow['recipients'] = await self._gdm_recipients(channel_id) + + return drow return None