diff --git a/litecord/blueprints/checks.py b/litecord/blueprints/checks.py
index 9c9253a..a27fcb4 100644
--- a/litecord/blueprints/checks.py
+++ b/litecord/blueprints/checks.py
@@ -17,6 +17,8 @@ along with this program. If not, see .
"""
+from typing import Union, List
+
from quart import current_app as app
from litecord.enums import ChannelType, GUILD_CHANS
@@ -53,16 +55,23 @@ async def guild_owner_check(user_id: int, guild_id: int):
raise Forbidden('You are not the owner of the guild')
-async def channel_check(user_id, channel_id):
+async def channel_check(user_id, channel_id, *,
+ only: Union[ChannelType, List[ChannelType]] = None):
"""Check if the current user is authorized
to read the channel's information."""
chan_type = await app.storage.get_chan_type(channel_id)
if chan_type is None:
- raise ChannelNotFound(f'channel type not found')
+ raise ChannelNotFound('channel type not found')
ctype = ChannelType(chan_type)
+ if not isinstance(only, list):
+ only = [only]
+
+ if only and ctype not in only:
+ raise ChannelNotFound('invalid channel type')
+
if ctype in GUILD_CHANS:
guild_id = await app.db.fetchval("""
SELECT guild_id
@@ -77,6 +86,9 @@ async def channel_check(user_id, channel_id):
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
return ctype, peer_id
+ if ctype == ChannelType.GROUP_DM:
+ return ctype
+
async def guild_perm_check(user_id, guild_id, permission: str):
"""Check guild permissions for a user."""
diff --git a/litecord/blueprints/dm_channels.py b/litecord/blueprints/dm_channels.py
index c908e7d..310150b 100644
--- a/litecord/blueprints/dm_channels.py
+++ b/litecord/blueprints/dm_channels.py
@@ -20,6 +20,10 @@ along with this program. If not, see .
from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger
+from litecord.blueprints.auth import token_check
+from litecord.blueprints.checks import channel_check
+from litecord.enums import ChannelType
+
log = Logger(__name__)
bp = Blueprint('dm_channels', __name__)
@@ -27,10 +31,17 @@ bp = Blueprint('dm_channels', __name__)
@bp.route('//receipients/', methods=['PUT'])
async def add_to_group_dm(dm_chan, user_id):
"""Adds a member to a group dm OR creates a group dm."""
- pass
+ user_id = await token_check()
+ ctype = await channel_check(
+ user_id, dm_chan,
+ only=[ChannelType.DM, ChannelType.GROUP_DM]
+ )
@bp.route('//recipients/', methods=['DELETE'])
async def remove_from_group_dm(dm_chan, user_id):
"""Remove users from group dm."""
- pass
+ user_id = await token_check()
+ ctype = await channel_check(
+ user_id, dm_chan, only=ChannelType.GROUP_DM
+ )