channel.messages: add permission checks

this commit only adds permission checking to most parts of the
message endpoints.

 - channel.messages: fix extract_limit's default param
 - channel.messages: check send_messages, mention_everyone, send_tts_messages
 - channel.messages: check manage_messages
 - blueprints.checks: add guild_perm_check, channel_perm_check
 - errors: add error_code property, change some inheritance
 - permissions: fix base_permissions
 - storage: fix get_reactions
 - storage: remove print-debug
 - run: use error_code property when given
This commit is contained in:
Luna Mendes 2018-11-04 19:55:21 -03:00
parent 87dd70b4d9
commit da8b049174
6 changed files with 94 additions and 26 deletions

View File

@ -4,7 +4,7 @@ from logbook import Logger
from litecord.blueprints.auth import token_check from litecord.blueprints.auth import token_check
from litecord.blueprints.checks import channel_check from litecord.blueprints.checks import channel_check, channel_perm_check
from litecord.blueprints.dms import try_dm_state from litecord.blueprints.dms import try_dm_state
from litecord.errors import MessageNotFound, Forbidden, BadRequest from litecord.errors import MessageNotFound, Forbidden, BadRequest
from litecord.enums import MessageType, ChannelType, GUILD_CHANS from litecord.enums import MessageType, ChannelType, GUILD_CHANS
@ -18,7 +18,7 @@ bp = Blueprint('channel_messages', __name__)
def extract_limit(request, default: int = 50): def extract_limit(request, default: int = 50):
try: try:
limit = int(request.args.get('limit', 50)) limit = int(request.args.get('limit', default))
if limit not in range(0, 100): if limit not in range(0, 100):
raise ValueError() raise ValueError()
@ -142,12 +142,25 @@ async def create_message(channel_id):
user_id = await token_check() user_id = await token_check()
ctype, guild_id = await channel_check(user_id, channel_id) ctype, guild_id = await channel_check(user_id, channel_id)
if ctype in GUILD_CHANS:
await channel_perm_check(user_id, channel_id, 'send_messages')
j = validate(await request.get_json(), MESSAGE_CREATE) j = validate(await request.get_json(), MESSAGE_CREATE)
message_id = get_snowflake() message_id = get_snowflake()
# TODO: check SEND_MESSAGES permission
# TODO: check connection to the gateway # TODO: check connection to the gateway
mentions_everyone = ('@everyone' in j['content'] and
await channel_perm_check(
user_id, channel_id, 'mention_everyone', False
)
)
is_tts = (j.get('tts', False) and
await channel_perm_check(
user_id, channel_id, 'send_tts_messages', False
))
await app.db.execute( await app.db.execute(
""" """
INSERT INTO messages (id, channel_id, author_id, content, tts, INSERT INTO messages (id, channel_id, author_id, content, tts,
@ -159,11 +172,9 @@ async def create_message(channel_id):
user_id, user_id,
j['content'], j['content'],
# TODO: check SEND_TTS_MESSAGES is_tts,
j.get('tts', False), mentions_everyone,
# TODO: check MENTION_EVERYONE permissions
'@everyone' in j['content'],
int(j.get('nonce', 0)), int(j.get('nonce', 0)),
MessageType.DEFAULT.value MessageType.DEFAULT.value
) )
@ -238,8 +249,13 @@ async def delete_message(channel_id, message_id):
WHERE messages.id = $1 WHERE messages.id = $1
""", message_id) """, message_id)
# TODO: MANAGE_MESSAGES permission check by_perm = await channel_perm_check(
if author_id != user_id: user_id, channel_id, 'manage_messages', False
)
by_ownership = author_id == user_id
if not by_perm and not by_ownership:
raise Forbidden('You can not delete this message') raise Forbidden('You can not delete this message')
await app.db.execute(""" await app.db.execute("""

View File

@ -1,7 +1,10 @@
from quart import current_app as app from quart import current_app as app
from ..enums import ChannelType, GUILD_CHANS from litecord.enums import ChannelType, GUILD_CHANS
from ..errors import GuildNotFound, ChannelNotFound, Forbidden from litecord.errors import (
GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions
)
from litecord.permissions import base_permissions, get_permissions
async def guild_check(user_id: int, guild_id: int): async def guild_check(user_id: int, guild_id: int):
@ -54,3 +57,27 @@ async def channel_check(user_id, channel_id):
if ctype == ChannelType.DM: if ctype == ChannelType.DM:
peer_id = await app.storage.get_dm_peer(channel_id, user_id) peer_id = await app.storage.get_dm_peer(channel_id, user_id)
return ctype, peer_id return ctype, peer_id
async def guild_perm_check(user_id, guild_id, permission: str):
"""Check guild permissions for a user."""
base_perms = await base_permissions(user_id, guild_id)
hasperm = getattr(base_perms.bits, permission)
if not hasperm:
raise MissingPermissions('Missing permissions.')
async def channel_perm_check(user_id, channel_id,
permission: str, raise_err=True):
"""Check channel permissions for a user."""
base_perms = await get_permissions(user_id, channel_id)
hasperm = getattr(base_perms.bits, permission)
print(base_perms)
print(base_perms.binary)
if not hasperm and raise_err:
raise MissingPermissions('Missing permissions.')
return hasperm

View File

@ -29,22 +29,26 @@ class NotFound(LitecordError):
status_code = 404 status_code = 404
class GuildNotFound(LitecordError): class GuildNotFound(NotFound):
status_code = 404 error_code = 10004
class ChannelNotFound(LitecordError): class ChannelNotFound(NotFound):
status_code = 404 error_code = 10003
class MessageNotFound(LitecordError): class MessageNotFound(NotFound):
status_code = 404 error_code = 10008
class Ratelimited(LitecordError): class Ratelimited(LitecordError):
status_code = 429 status_code = 429
class MissingPermissions(Forbidden):
error_code = 50013
class WebsocketClose(Exception): class WebsocketClose(Exception):
@property @property
def code(self): def code(self):

View File

@ -52,6 +52,9 @@ class Permissions(ctypes.Union):
def __init__(self, val: int): def __init__(self, val: int):
self.binary = val self.binary = val
def __repr__(self):
return f'<Permissions binary={self.binary}>'
def __int__(self): def __int__(self):
return self.binary return self.binary
@ -88,14 +91,25 @@ async def base_permissions(member_id, guild_id) -> Permissions:
WHERE guild_id = $1 WHERE guild_id = $1
""", guild_id) """, guild_id)
permissions = everyone_perms permissions = Permissions(everyone_perms)
role_perms = await app.db.fetch(""" role_ids = await app.db.fetch("""
SELECT permissions SELECT role_id
FROM roles FROM member_roles
WHERE guild_id = $1 AND user_id = $2 WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id) """, guild_id, member_id)
role_perms = []
for row in role_ids:
rperm = await app.db.fetchval("""
SELECT permissions
FROM roles
WHERE id = $1
""", row['role_id'])
role_perms.append(rperm)
for perm_num in role_perms: for perm_num in role_perms:
permissions.binary |= perm_num permissions.binary |= perm_num
@ -180,6 +194,11 @@ async def compute_overwrites(base_perms, user_id, channel_id: int,
async def get_permissions(member_id, channel_id): async def get_permissions(member_id, channel_id):
"""Get all the permissions for a user in a channel.""" """Get all the permissions for a user in a channel."""
guild_id = await app.storage.guild_from_channel(channel_id) guild_id = await app.storage.guild_from_channel(channel_id)
# for non guild channels
if not guild_id:
return ALL_PERMISSIONS
base_perms = await base_permissions(member_id, guild_id) base_perms = await base_permissions(member_id, guild_id)
return await compute_overwrites(base_perms, member_id, return await compute_overwrites(base_perms, member_id,

View File

@ -567,8 +567,9 @@ class Storage:
reactions = await self.db.fetch(""" reactions = await self.db.fetch("""
SELECT user_id, emoji_type, emoji_id, emoji_text SELECT user_id, emoji_type, emoji_id, emoji_text
FROM message_reactions FROM message_reactions
WHERE message_id = $1
ORDER BY react_ts ORDER BY react_ts
""") """, message_id)
# ordered list of emoji # ordered list of emoji
emoji = [] emoji = []
@ -616,15 +617,12 @@ class Storage:
stats = react_stats[main_emoji] stats = react_stats[main_emoji]
stats['count'] += 1 stats['count'] += 1
print(row['user_id'], user_id)
if row['user_id'] == user_id: if row['user_id'] == user_id:
stats['me'] = True stats['me'] = True
# after processing reaction counts, # after processing reaction counts,
# we get them in the same order # we get them in the same order
# they were defined in the first loop. # they were defined in the first loop.
print(emoji)
print(react_stats)
return list(map(react_stats.get, emoji)) return list(map(react_stats.get, emoji))
async def get_message(self, message_id: int, user_id=None) -> Dict: async def get_message(self, message_id: int, user_id=None) -> Dict:

6
run.py
View File

@ -202,9 +202,13 @@ async def handle_litecord_err(err):
except IndexError: except IndexError:
ejson = {} ejson = {}
try:
ejson['code'] = err.error_code
except AttributeError:
pass
return jsonify({ return jsonify({
'error': True, 'error': True,
# 'code': err.code,
'status': err.status_code, 'status': err.status_code,
'message': err.message, 'message': err.message,
**ejson **ejson