mirror of https://gitlab.com/litecord/litecord.git
channel.messages: add permission checks
this commit only adds permission checking to most parts of the message endpoints. - channel.messages: fix extract_limit's default param - channel.messages: check send_messages, mention_everyone, send_tts_messages - channel.messages: check manage_messages - blueprints.checks: add guild_perm_check, channel_perm_check - errors: add error_code property, change some inheritance - permissions: fix base_permissions - storage: fix get_reactions - storage: remove print-debug - run: use error_code property when given
This commit is contained in:
parent
87dd70b4d9
commit
da8b049174
|
|
@ -4,7 +4,7 @@ from logbook import Logger
|
||||||
|
|
||||||
|
|
||||||
from litecord.blueprints.auth import token_check
|
from litecord.blueprints.auth import token_check
|
||||||
from litecord.blueprints.checks import channel_check
|
from litecord.blueprints.checks import channel_check, channel_perm_check
|
||||||
from litecord.blueprints.dms import try_dm_state
|
from litecord.blueprints.dms import try_dm_state
|
||||||
from litecord.errors import MessageNotFound, Forbidden, BadRequest
|
from litecord.errors import MessageNotFound, Forbidden, BadRequest
|
||||||
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
||||||
|
|
@ -18,7 +18,7 @@ bp = Blueprint('channel_messages', __name__)
|
||||||
|
|
||||||
def extract_limit(request, default: int = 50):
|
def extract_limit(request, default: int = 50):
|
||||||
try:
|
try:
|
||||||
limit = int(request.args.get('limit', 50))
|
limit = int(request.args.get('limit', default))
|
||||||
|
|
||||||
if limit not in range(0, 100):
|
if limit not in range(0, 100):
|
||||||
raise ValueError()
|
raise ValueError()
|
||||||
|
|
@ -142,12 +142,25 @@ async def create_message(channel_id):
|
||||||
user_id = await token_check()
|
user_id = await token_check()
|
||||||
ctype, guild_id = await channel_check(user_id, channel_id)
|
ctype, guild_id = await channel_check(user_id, channel_id)
|
||||||
|
|
||||||
|
if ctype in GUILD_CHANS:
|
||||||
|
await channel_perm_check(user_id, channel_id, 'send_messages')
|
||||||
|
|
||||||
j = validate(await request.get_json(), MESSAGE_CREATE)
|
j = validate(await request.get_json(), MESSAGE_CREATE)
|
||||||
message_id = get_snowflake()
|
message_id = get_snowflake()
|
||||||
|
|
||||||
# TODO: check SEND_MESSAGES permission
|
|
||||||
# TODO: check connection to the gateway
|
# TODO: check connection to the gateway
|
||||||
|
|
||||||
|
mentions_everyone = ('@everyone' in j['content'] and
|
||||||
|
await channel_perm_check(
|
||||||
|
user_id, channel_id, 'mention_everyone', False
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
is_tts = (j.get('tts', False) and
|
||||||
|
await channel_perm_check(
|
||||||
|
user_id, channel_id, 'send_tts_messages', False
|
||||||
|
))
|
||||||
|
|
||||||
await app.db.execute(
|
await app.db.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO messages (id, channel_id, author_id, content, tts,
|
INSERT INTO messages (id, channel_id, author_id, content, tts,
|
||||||
|
|
@ -159,11 +172,9 @@ async def create_message(channel_id):
|
||||||
user_id,
|
user_id,
|
||||||
j['content'],
|
j['content'],
|
||||||
|
|
||||||
# TODO: check SEND_TTS_MESSAGES
|
is_tts,
|
||||||
j.get('tts', False),
|
mentions_everyone,
|
||||||
|
|
||||||
# TODO: check MENTION_EVERYONE permissions
|
|
||||||
'@everyone' in j['content'],
|
|
||||||
int(j.get('nonce', 0)),
|
int(j.get('nonce', 0)),
|
||||||
MessageType.DEFAULT.value
|
MessageType.DEFAULT.value
|
||||||
)
|
)
|
||||||
|
|
@ -238,8 +249,13 @@ async def delete_message(channel_id, message_id):
|
||||||
WHERE messages.id = $1
|
WHERE messages.id = $1
|
||||||
""", message_id)
|
""", message_id)
|
||||||
|
|
||||||
# TODO: MANAGE_MESSAGES permission check
|
by_perm = await channel_perm_check(
|
||||||
if author_id != user_id:
|
user_id, channel_id, 'manage_messages', False
|
||||||
|
)
|
||||||
|
|
||||||
|
by_ownership = author_id == user_id
|
||||||
|
|
||||||
|
if not by_perm and not by_ownership:
|
||||||
raise Forbidden('You can not delete this message')
|
raise Forbidden('You can not delete this message')
|
||||||
|
|
||||||
await app.db.execute("""
|
await app.db.execute("""
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
from quart import current_app as app
|
from quart import current_app as app
|
||||||
|
|
||||||
from ..enums import ChannelType, GUILD_CHANS
|
from litecord.enums import ChannelType, GUILD_CHANS
|
||||||
from ..errors import GuildNotFound, ChannelNotFound, Forbidden
|
from litecord.errors import (
|
||||||
|
GuildNotFound, ChannelNotFound, Forbidden, MissingPermissions
|
||||||
|
)
|
||||||
|
from litecord.permissions import base_permissions, get_permissions
|
||||||
|
|
||||||
|
|
||||||
async def guild_check(user_id: int, guild_id: int):
|
async def guild_check(user_id: int, guild_id: int):
|
||||||
|
|
@ -54,3 +57,27 @@ async def channel_check(user_id, channel_id):
|
||||||
if ctype == ChannelType.DM:
|
if ctype == ChannelType.DM:
|
||||||
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
|
peer_id = await app.storage.get_dm_peer(channel_id, user_id)
|
||||||
return ctype, peer_id
|
return ctype, peer_id
|
||||||
|
|
||||||
|
|
||||||
|
async def guild_perm_check(user_id, guild_id, permission: str):
|
||||||
|
"""Check guild permissions for a user."""
|
||||||
|
base_perms = await base_permissions(user_id, guild_id)
|
||||||
|
hasperm = getattr(base_perms.bits, permission)
|
||||||
|
|
||||||
|
if not hasperm:
|
||||||
|
raise MissingPermissions('Missing permissions.')
|
||||||
|
|
||||||
|
|
||||||
|
async def channel_perm_check(user_id, channel_id,
|
||||||
|
permission: str, raise_err=True):
|
||||||
|
"""Check channel permissions for a user."""
|
||||||
|
base_perms = await get_permissions(user_id, channel_id)
|
||||||
|
hasperm = getattr(base_perms.bits, permission)
|
||||||
|
|
||||||
|
print(base_perms)
|
||||||
|
print(base_perms.binary)
|
||||||
|
|
||||||
|
if not hasperm and raise_err:
|
||||||
|
raise MissingPermissions('Missing permissions.')
|
||||||
|
|
||||||
|
return hasperm
|
||||||
|
|
|
||||||
|
|
@ -29,22 +29,26 @@ class NotFound(LitecordError):
|
||||||
status_code = 404
|
status_code = 404
|
||||||
|
|
||||||
|
|
||||||
class GuildNotFound(LitecordError):
|
class GuildNotFound(NotFound):
|
||||||
status_code = 404
|
error_code = 10004
|
||||||
|
|
||||||
|
|
||||||
class ChannelNotFound(LitecordError):
|
class ChannelNotFound(NotFound):
|
||||||
status_code = 404
|
error_code = 10003
|
||||||
|
|
||||||
|
|
||||||
class MessageNotFound(LitecordError):
|
class MessageNotFound(NotFound):
|
||||||
status_code = 404
|
error_code = 10008
|
||||||
|
|
||||||
|
|
||||||
class Ratelimited(LitecordError):
|
class Ratelimited(LitecordError):
|
||||||
status_code = 429
|
status_code = 429
|
||||||
|
|
||||||
|
|
||||||
|
class MissingPermissions(Forbidden):
|
||||||
|
error_code = 50013
|
||||||
|
|
||||||
|
|
||||||
class WebsocketClose(Exception):
|
class WebsocketClose(Exception):
|
||||||
@property
|
@property
|
||||||
def code(self):
|
def code(self):
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,9 @@ class Permissions(ctypes.Union):
|
||||||
def __init__(self, val: int):
|
def __init__(self, val: int):
|
||||||
self.binary = val
|
self.binary = val
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f'<Permissions binary={self.binary}>'
|
||||||
|
|
||||||
def __int__(self):
|
def __int__(self):
|
||||||
return self.binary
|
return self.binary
|
||||||
|
|
||||||
|
|
@ -88,14 +91,25 @@ async def base_permissions(member_id, guild_id) -> Permissions:
|
||||||
WHERE guild_id = $1
|
WHERE guild_id = $1
|
||||||
""", guild_id)
|
""", guild_id)
|
||||||
|
|
||||||
permissions = everyone_perms
|
permissions = Permissions(everyone_perms)
|
||||||
|
|
||||||
role_perms = await app.db.fetch("""
|
role_ids = await app.db.fetch("""
|
||||||
SELECT permissions
|
SELECT role_id
|
||||||
FROM roles
|
FROM member_roles
|
||||||
WHERE guild_id = $1 AND user_id = $2
|
WHERE guild_id = $1 AND user_id = $2
|
||||||
""", guild_id, member_id)
|
""", guild_id, member_id)
|
||||||
|
|
||||||
|
role_perms = []
|
||||||
|
|
||||||
|
for row in role_ids:
|
||||||
|
rperm = await app.db.fetchval("""
|
||||||
|
SELECT permissions
|
||||||
|
FROM roles
|
||||||
|
WHERE id = $1
|
||||||
|
""", row['role_id'])
|
||||||
|
|
||||||
|
role_perms.append(rperm)
|
||||||
|
|
||||||
for perm_num in role_perms:
|
for perm_num in role_perms:
|
||||||
permissions.binary |= perm_num
|
permissions.binary |= perm_num
|
||||||
|
|
||||||
|
|
@ -180,6 +194,11 @@ async def compute_overwrites(base_perms, user_id, channel_id: int,
|
||||||
async def get_permissions(member_id, channel_id):
|
async def get_permissions(member_id, channel_id):
|
||||||
"""Get all the permissions for a user in a channel."""
|
"""Get all the permissions for a user in a channel."""
|
||||||
guild_id = await app.storage.guild_from_channel(channel_id)
|
guild_id = await app.storage.guild_from_channel(channel_id)
|
||||||
|
|
||||||
|
# for non guild channels
|
||||||
|
if not guild_id:
|
||||||
|
return ALL_PERMISSIONS
|
||||||
|
|
||||||
base_perms = await base_permissions(member_id, guild_id)
|
base_perms = await base_permissions(member_id, guild_id)
|
||||||
|
|
||||||
return await compute_overwrites(base_perms, member_id,
|
return await compute_overwrites(base_perms, member_id,
|
||||||
|
|
|
||||||
|
|
@ -567,8 +567,9 @@ class Storage:
|
||||||
reactions = await self.db.fetch("""
|
reactions = await self.db.fetch("""
|
||||||
SELECT user_id, emoji_type, emoji_id, emoji_text
|
SELECT user_id, emoji_type, emoji_id, emoji_text
|
||||||
FROM message_reactions
|
FROM message_reactions
|
||||||
|
WHERE message_id = $1
|
||||||
ORDER BY react_ts
|
ORDER BY react_ts
|
||||||
""")
|
""", message_id)
|
||||||
|
|
||||||
# ordered list of emoji
|
# ordered list of emoji
|
||||||
emoji = []
|
emoji = []
|
||||||
|
|
@ -616,15 +617,12 @@ class Storage:
|
||||||
stats = react_stats[main_emoji]
|
stats = react_stats[main_emoji]
|
||||||
stats['count'] += 1
|
stats['count'] += 1
|
||||||
|
|
||||||
print(row['user_id'], user_id)
|
|
||||||
if row['user_id'] == user_id:
|
if row['user_id'] == user_id:
|
||||||
stats['me'] = True
|
stats['me'] = True
|
||||||
|
|
||||||
# after processing reaction counts,
|
# after processing reaction counts,
|
||||||
# we get them in the same order
|
# we get them in the same order
|
||||||
# they were defined in the first loop.
|
# they were defined in the first loop.
|
||||||
print(emoji)
|
|
||||||
print(react_stats)
|
|
||||||
return list(map(react_stats.get, emoji))
|
return list(map(react_stats.get, emoji))
|
||||||
|
|
||||||
async def get_message(self, message_id: int, user_id=None) -> Dict:
|
async def get_message(self, message_id: int, user_id=None) -> Dict:
|
||||||
|
|
|
||||||
6
run.py
6
run.py
|
|
@ -202,9 +202,13 @@ async def handle_litecord_err(err):
|
||||||
except IndexError:
|
except IndexError:
|
||||||
ejson = {}
|
ejson = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
ejson['code'] = err.error_code
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
'error': True,
|
'error': True,
|
||||||
# 'code': err.code,
|
|
||||||
'status': err.status_code,
|
'status': err.status_code,
|
||||||
'message': err.message,
|
'message': err.message,
|
||||||
**ejson
|
**ejson
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue