mirror of https://gitlab.com/litecord/litecord.git
checks: add only kwarg to filter allowed channels in route
- dm_channels: add channel_check usage
This commit is contained in:
parent
682a527a55
commit
1bb2a46d9e
|
|
@ -17,6 +17,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Union, List
|
||||||
|
|
||||||
from quart import current_app as app
|
from quart import current_app as app
|
||||||
|
|
||||||
from litecord.enums import ChannelType, GUILD_CHANS
|
from litecord.enums import ChannelType, GUILD_CHANS
|
||||||
|
|
@ -53,16 +55,23 @@ async def guild_owner_check(user_id: int, guild_id: int):
|
||||||
raise Forbidden('You are not the owner of the guild')
|
raise Forbidden('You are not the owner of the guild')
|
||||||
|
|
||||||
|
|
||||||
async def channel_check(user_id, channel_id):
|
async def channel_check(user_id, channel_id, *,
|
||||||
|
only: Union[ChannelType, List[ChannelType]] = None):
|
||||||
"""Check if the current user is authorized
|
"""Check if the current user is authorized
|
||||||
to read the channel's information."""
|
to read the channel's information."""
|
||||||
chan_type = await app.storage.get_chan_type(channel_id)
|
chan_type = await app.storage.get_chan_type(channel_id)
|
||||||
|
|
||||||
if chan_type is None:
|
if chan_type is None:
|
||||||
raise ChannelNotFound(f'channel type not found')
|
raise ChannelNotFound('channel type not found')
|
||||||
|
|
||||||
ctype = ChannelType(chan_type)
|
ctype = ChannelType(chan_type)
|
||||||
|
|
||||||
|
if not isinstance(only, list):
|
||||||
|
only = [only]
|
||||||
|
|
||||||
|
if only and ctype not in only:
|
||||||
|
raise ChannelNotFound('invalid channel type')
|
||||||
|
|
||||||
if ctype in GUILD_CHANS:
|
if ctype in GUILD_CHANS:
|
||||||
guild_id = await app.db.fetchval("""
|
guild_id = await app.db.fetchval("""
|
||||||
SELECT guild_id
|
SELECT guild_id
|
||||||
|
|
@ -77,6 +86,9 @@ async def channel_check(user_id, channel_id):
|
||||||
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
|
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
|
||||||
return ctype, peer_id
|
return ctype, peer_id
|
||||||
|
|
||||||
|
if ctype == ChannelType.GROUP_DM:
|
||||||
|
return ctype
|
||||||
|
|
||||||
|
|
||||||
async def guild_perm_check(user_id, guild_id, permission: str):
|
async def guild_perm_check(user_id, guild_id, permission: str):
|
||||||
"""Check guild permissions for a user."""
|
"""Check guild permissions for a user."""
|
||||||
|
|
|
||||||
|
|
@ -20,6 +20,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
from quart import Blueprint, request, current_app as app, jsonify
|
from quart import Blueprint, request, current_app as app, jsonify
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
|
from litecord.blueprints.auth import token_check
|
||||||
|
from litecord.blueprints.checks import channel_check
|
||||||
|
from litecord.enums import ChannelType
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
bp = Blueprint('dm_channels', __name__)
|
bp = Blueprint('dm_channels', __name__)
|
||||||
|
|
||||||
|
|
@ -27,10 +31,17 @@ bp = Blueprint('dm_channels', __name__)
|
||||||
@bp.route('/<int:dm_chan>/receipients/<int:user_id>', methods=['PUT'])
|
@bp.route('/<int:dm_chan>/receipients/<int:user_id>', methods=['PUT'])
|
||||||
async def add_to_group_dm(dm_chan, user_id):
|
async def add_to_group_dm(dm_chan, user_id):
|
||||||
"""Adds a member to a group dm OR creates a group dm."""
|
"""Adds a member to a group dm OR creates a group dm."""
|
||||||
pass
|
user_id = await token_check()
|
||||||
|
ctype = await channel_check(
|
||||||
|
user_id, dm_chan,
|
||||||
|
only=[ChannelType.DM, ChannelType.GROUP_DM]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/<int:dm_chan>/recipients/<int:user_id>', methods=['DELETE'])
|
@bp.route('/<int:dm_chan>/recipients/<int:user_id>', methods=['DELETE'])
|
||||||
async def remove_from_group_dm(dm_chan, user_id):
|
async def remove_from_group_dm(dm_chan, user_id):
|
||||||
"""Remove users from group dm."""
|
"""Remove users from group dm."""
|
||||||
pass
|
user_id = await token_check()
|
||||||
|
ctype = await channel_check(
|
||||||
|
user_id, dm_chan, only=ChannelType.GROUP_DM
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue