diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py index 9c9253a..a27fcb4 100644 --- a/litecord/blueprints/checks.py +++ b/litecord/blueprints/checks.py @@ -17,6 +17,8 @@ along with this program. If not, see . """ +from typing import Union, List + from quart import current_app as app from litecord.enums import ChannelType, GUILD_CHANS @@ -53,16 +55,23 @@ async def guild_owner_check(user_id: int, guild_id: int): raise Forbidden('You are not the owner of the guild') -async def channel_check(user_id, channel_id): +async def channel_check(user_id, channel_id, *, + only: Union[ChannelType, List[ChannelType]] = None): """Check if the current user is authorized to read the channel's information.""" chan_type = await app.storage.get_chan_type(channel_id) if chan_type is None: - raise ChannelNotFound(f'channel type not found') + raise ChannelNotFound('channel type not found') ctype = ChannelType(chan_type) + if not isinstance(only, list): + only = [only] + + if only and ctype not in only: + raise ChannelNotFound('invalid channel type') + if ctype in GUILD_CHANS: guild_id = await app.db.fetchval(""" SELECT guild_id @@ -77,6 +86,9 @@ async def channel_check(user_id, channel_id): peer_id = await app.storage.get_dm_peer(channel_id, user_id) return ctype, peer_id + if ctype == ChannelType.GROUP_DM: + return ctype + async def guild_perm_check(user_id, guild_id, permission: str): """Check guild permissions for a user.""" diff --git a/litecord/blueprints/dm_channels.py b/litecord/blueprints/dm_channels.py index c908e7d..310150b 100644 --- a/litecord/blueprints/dm_channels.py +++ b/litecord/blueprints/dm_channels.py @@ -20,6 +20,10 @@ along with this program. If not, see . from quart import Blueprint, request, current_app as app, jsonify from logbook import Logger +from litecord.blueprints.auth import token_check +from litecord.blueprints.checks import channel_check +from litecord.enums import ChannelType + log = Logger(__name__) bp = Blueprint('dm_channels', __name__) @@ -27,10 +31,17 @@ bp = Blueprint('dm_channels', __name__) @bp.route('//receipients/', methods=['PUT']) async def add_to_group_dm(dm_chan, user_id): """Adds a member to a group dm OR creates a group dm.""" - pass + user_id = await token_check() + ctype = await channel_check( + user_id, dm_chan, + only=[ChannelType.DM, ChannelType.GROUP_DM] + ) @bp.route('//recipients/', methods=['DELETE']) async def remove_from_group_dm(dm_chan, user_id): """Remove users from group dm.""" - pass + user_id = await token_check() + ctype = await channel_check( + user_id, dm_chan, only=ChannelType.GROUP_DM + )