checks: add only kwarg to filter allowed channels in route

- dm_channels: add channel_check usage
This commit is contained in:
Luna 2019-02-08 18:40:21 -03:00
parent 682a527a55
commit 1bb2a46d9e
2 changed files with 27 additions and 4 deletions

View File

@ -17,6 +17,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import Union, List
from quart import current_app as app from quart import current_app as app
from litecord.enums import ChannelType, GUILD_CHANS from litecord.enums import ChannelType, GUILD_CHANS
@ -53,16 +55,23 @@ async def guild_owner_check(user_id: int, guild_id: int):
raise Forbidden('You are not the owner of the guild') raise Forbidden('You are not the owner of the guild')
async def channel_check(user_id, channel_id): async def channel_check(user_id, channel_id, *,
only: Union[ChannelType, List[ChannelType]] = None):
"""Check if the current user is authorized """Check if the current user is authorized
to read the channel's information.""" to read the channel's information."""
chan_type = await app.storage.get_chan_type(channel_id) chan_type = await app.storage.get_chan_type(channel_id)
if chan_type is None: if chan_type is None:
raise ChannelNotFound(f'channel type not found') raise ChannelNotFound('channel type not found')
ctype = ChannelType(chan_type) ctype = ChannelType(chan_type)
if not isinstance(only, list):
only = [only]
if only and ctype not in only:
raise ChannelNotFound('invalid channel type')
if ctype in GUILD_CHANS: if ctype in GUILD_CHANS:
guild_id = await app.db.fetchval(""" guild_id = await app.db.fetchval("""
SELECT guild_id SELECT guild_id
@ -77,6 +86,9 @@ async def channel_check(user_id, channel_id):
peer_id = await app.storage.get_dm_peer(channel_id, user_id) peer_id = await app.storage.get_dm_peer(channel_id, user_id)
return ctype, peer_id return ctype, peer_id
if ctype == ChannelType.GROUP_DM:
return ctype
async def guild_perm_check(user_id, guild_id, permission: str): async def guild_perm_check(user_id, guild_id, permission: str):
"""Check guild permissions for a user.""" """Check guild permissions for a user."""

View File

@ -20,6 +20,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
from quart import Blueprint, request, current_app as app, jsonify from quart import Blueprint, request, current_app as app, jsonify
from logbook import Logger from logbook import Logger
from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check
from litecord.enums import ChannelType
log = Logger(__name__) log = Logger(__name__)
bp = Blueprint('dm_channels', __name__) bp = Blueprint('dm_channels', __name__)
@ -27,10 +31,17 @@ bp = Blueprint('dm_channels', __name__)
@bp.route('/<int:dm_chan>/receipients/<int:user_id>', methods=['PUT']) @bp.route('/<int:dm_chan>/receipients/<int:user_id>', methods=['PUT'])
async def add_to_group_dm(dm_chan, user_id): async def add_to_group_dm(dm_chan, user_id):
"""Adds a member to a group dm OR creates a group dm.""" """Adds a member to a group dm OR creates a group dm."""
pass user_id = await token_check()
ctype = await channel_check(
user_id, dm_chan,
only=[ChannelType.DM, ChannelType.GROUP_DM]
)
@bp.route('/<int:dm_chan>/recipients/<int:user_id>', methods=['DELETE']) @bp.route('/<int:dm_chan>/recipients/<int:user_id>', methods=['DELETE'])
async def remove_from_group_dm(dm_chan, user_id): async def remove_from_group_dm(dm_chan, user_id):
"""Remove users from group dm.""" """Remove users from group dm."""
pass user_id = await token_check()
ctype = await channel_check(
user_id, dm_chan, only=ChannelType.GROUP_DM
)