mirror of https://gitlab.com/litecord/litecord.git
channel.messages: add permission checks
this commit only adds permission checking to most parts of the message endpoints. - channel.messages: fix extract_limit's default param - channel.messages: check send_messages, mention_everyone, send_tts_messages - channel.messages: check manage_messages - blueprints.checks: add guild_perm_check, channel_perm_check - errors: add error_code property, change some inheritance - permissions: fix base_permissions - storage: fix get_reactions - storage: remove print-debug - run: use error_code property when given
This commit is contained in:
parent
87dd70b4d9
commit
da8b049174
|
|
@ -4,7 +4,7 @@ from logbook import Logger
|
|||
|
||||
|
||||
from litecord.blueprints.auth import token_check
|
||||
from litecord.blueprints.checks import channel_check
|
||||
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||
from litecord.blueprints.dms import try_dm_state
|
||||
from litecord.errors import MessageNotFound, Forbidden, BadRequest
|
||||
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
||||
|
|
@ -18,7 +18,7 @@ bp = Blueprint('channel_messages', __name__)
|
|||
|
||||
def extract_limit(request, default: int = 50):
|
||||
try:
|
||||
limit = int(request.args.get('limit', 50))
|
||||
limit = int(request.args.get('limit', default))
|
||||
|
||||
if limit not in range(0, 100):
|
||||
raise ValueError()
|
||||
|
|
@ -142,12 +142,25 @@ async def create_message(channel_id):
|
|||
user_id = await token_check()
|
||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||
|
||||
if ctype in GUILD_CHANS:
|
||||
await channel_perm_check(user_id, channel_id, 'send_messages')
|
||||
|
||||
j = validate(await request.get_json(), MESSAGE_CREATE)
|
||||
message_id = get_snowflake()
|
||||
|
||||
# TODO: check SEND_MESSAGES permission
|
||||
# TODO: check connection to the gateway
|
||||
|
||||
mentions_everyone = ('@everyone' in j['content'] and
|
||||
await channel_perm_check(
|
||||
user_id, channel_id, 'mention_everyone', False
|
||||
)
|
||||
)
|
||||
|
||||
is_tts = (j.get('tts', False) and
|
||||
await channel_perm_check(
|
||||
user_id, channel_id, 'send_tts_messages', False
|
||||
))
|
||||
|
||||
await app.db.execute(
|
||||
"""
|
||||
INSERT INTO messages (id, channel_id, author_id, content, tts,
|
||||
|
|
@ -159,11 +172,9 @@ async def create_message(channel_id):
|
|||
user_id,
|
||||
j['content'],
|
||||
|
||||
# TODO: check SEND_TTS_MESSAGES
|
||||
j.get('tts', False),
|
||||
is_tts,
|
||||
mentions_everyone,
|
||||
|
||||
# TODO: check MENTION_EVERYONE permissions
|
||||
'@everyone' in j['content'],
|
||||
int(j.get('nonce', 0)),
|
||||
MessageType.DEFAULT.value
|
||||
)
|
||||
|
|
@ -238,8 +249,13 @@ async def delete_message(channel_id, message_id):
|
|||
WHERE messages.id = $1
|
||||
""", message_id)
|
||||
|
||||
# TODO: MANAGE_MESSAGES permission check
|
||||
if author_id != user_id:
|
||||
by_perm = await channel_perm_check(
|
||||
user_id, channel_id, 'manage_messages', False
|
||||
)
|
||||
|
||||
by_ownership = author_id == user_id
|
||||
|
||||
if not by_perm and not by_ownership:
|
||||
raise Forbidden('You can not delete this message')
|
||||
|
||||
await app.db.execute("""
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
from quart import current_app as app
|
||||
|
||||
from ..enums import ChannelType, GUILD_CHANS
|
||||
from ..errors import GuildNotFound, ChannelNotFound, Forbidden
|
||||
from litecord.enums import ChannelType, GUILD_CHANS
|
||||
from litecord.errors import (
|
||||
GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions
|
||||
)
|
||||
from litecord.permissions import base_permissions, get_permissions
|
||||
|
||||
|
||||
async def guild_check(user_id: int, guild_id: int):
|
||||
|
|
@ -54,3 +57,27 @@ async def channel_check(user_id, channel_id):
|
|||
if ctype == ChannelType.DM:
|
||||
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
|
||||
return ctype, peer_id
|
||||
|
||||
|
||||
async def guild_perm_check(user_id, guild_id, permission: str):
|
||||
"""Check guild permissions for a user."""
|
||||
base_perms = await base_permissions(user_id, guild_id)
|
||||
hasperm = getattr(base_perms.bits, permission)
|
||||
|
||||
if not hasperm:
|
||||
raise MissingPermissions('Missing permissions.')
|
||||
|
||||
|
||||
async def channel_perm_check(user_id, channel_id,
|
||||
permission: str, raise_err=True):
|
||||
"""Check channel permissions for a user."""
|
||||
base_perms = await get_permissions(user_id, channel_id)
|
||||
hasperm = getattr(base_perms.bits, permission)
|
||||
|
||||
print(base_perms)
|
||||
print(base_perms.binary)
|
||||
|
||||
if not hasperm and raise_err:
|
||||
raise MissingPermissions('Missing permissions.')
|
||||
|
||||
return hasperm
|
||||
|
|
|
|||
|
|
@ -29,22 +29,26 @@ class NotFound(LitecordError):
|
|||
status_code = 404
|
||||
|
||||
|
||||
class GuildNotFound(LitecordError):
|
||||
status_code = 404
|
||||
class GuildNotFound(NotFound):
|
||||
error_code = 10004
|
||||
|
||||
|
||||
class ChannelNotFound(LitecordError):
|
||||
status_code = 404
|
||||
class ChannelNotFound(NotFound):
|
||||
error_code = 10003
|
||||
|
||||
|
||||
class MessageNotFound(LitecordError):
|
||||
status_code = 404
|
||||
class MessageNotFound(NotFound):
|
||||
error_code = 10008
|
||||
|
||||
|
||||
class Ratelimited(LitecordError):
|
||||
status_code = 429
|
||||
|
||||
|
||||
class MissingPermissions(Forbidden):
|
||||
error_code = 50013
|
||||
|
||||
|
||||
class WebsocketClose(Exception):
|
||||
@property
|
||||
def code(self):
|
||||
|
|
|
|||
|
|
@ -52,6 +52,9 @@ class Permissions(ctypes.Union):
|
|||
def __init__(self, val: int):
|
||||
self.binary = val
|
||||
|
||||
def __repr__(self):
|
||||
return f'<Permissions binary={self.binary}>'
|
||||
|
||||
def __int__(self):
|
||||
return self.binary
|
||||
|
||||
|
|
@ -88,14 +91,25 @@ async def base_permissions(member_id, guild_id) -> Permissions:
|
|||
WHERE guild_id = $1
|
||||
""", guild_id)
|
||||
|
||||
permissions = everyone_perms
|
||||
permissions = Permissions(everyone_perms)
|
||||
|
||||
role_perms = await app.db.fetch("""
|
||||
SELECT permissions
|
||||
FROM roles
|
||||
role_ids = await app.db.fetch("""
|
||||
SELECT role_id
|
||||
FROM member_roles
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
""", guild_id, member_id)
|
||||
|
||||
role_perms = []
|
||||
|
||||
for row in role_ids:
|
||||
rperm = await app.db.fetchval("""
|
||||
SELECT permissions
|
||||
FROM roles
|
||||
WHERE id = $1
|
||||
""", row['role_id'])
|
||||
|
||||
role_perms.append(rperm)
|
||||
|
||||
for perm_num in role_perms:
|
||||
permissions.binary |= perm_num
|
||||
|
||||
|
|
@ -180,6 +194,11 @@ async def compute_overwrites(base_perms, user_id, channel_id: int,
|
|||
async def get_permissions(member_id, channel_id):
|
||||
"""Get all the permissions for a user in a channel."""
|
||||
guild_id = await app.storage.guild_from_channel(channel_id)
|
||||
|
||||
# for non guild channels
|
||||
if not guild_id:
|
||||
return ALL_PERMISSIONS
|
||||
|
||||
base_perms = await base_permissions(member_id, guild_id)
|
||||
|
||||
return await compute_overwrites(base_perms, member_id,
|
||||
|
|
|
|||
|
|
@ -567,8 +567,9 @@ class Storage:
|
|||
reactions = await self.db.fetch("""
|
||||
SELECT user_id, emoji_type, emoji_id, emoji_text
|
||||
FROM message_reactions
|
||||
WHERE message_id = $1
|
||||
ORDER BY react_ts
|
||||
""")
|
||||
""", message_id)
|
||||
|
||||
# ordered list of emoji
|
||||
emoji = []
|
||||
|
|
@ -616,15 +617,12 @@ class Storage:
|
|||
stats = react_stats[main_emoji]
|
||||
stats['count'] += 1
|
||||
|
||||
print(row['user_id'], user_id)
|
||||
if row['user_id'] == user_id:
|
||||
stats['me'] = True
|
||||
|
||||
# after processing reaction counts,
|
||||
# we get them in the same order
|
||||
# they were defined in the first loop.
|
||||
print(emoji)
|
||||
print(react_stats)
|
||||
return list(map(react_stats.get, emoji))
|
||||
|
||||
async def get_message(self, message_id: int, user_id=None) -> Dict:
|
||||
|
|
|
|||
Loading…
Reference in New Issue