From 87dd70b4d9a16ae1cd60d0163219878006a2383f Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 4 Nov 2018 17:54:48 -0300 Subject: [PATCH] permissions: add basic permission api - litecord.auth: insert request.user_id - storage: add get_member_role_ids --- litecord/auth.py | 10 +++- litecord/permissions.py | 129 ++++++++++++++++++++++++++++++++++++++++ litecord/storage.py | 18 ++++-- 3 files changed, 150 insertions(+), 7 deletions(-) diff --git a/litecord/auth.py b/litecord/auth.py index af241b5..498aa59 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -58,6 +58,12 @@ async def raw_token_check(token, db=None): async def token_check(): """Check token information.""" + # first, check if the request info already has a uid + try: + return request.user_id + except AttributeError: + pass + try: token = request.headers['Authorization'] except KeyError: @@ -66,4 +72,6 @@ async def token_check(): if token.startswith('Bot '): token = token.replace('Bot ', '') - return await raw_token_check(token) + user_id = await raw_token_check(token) + request.user_id = user_id + return user_id diff --git a/litecord/permissions.py b/litecord/permissions.py index c5c5966..850759f 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -1,5 +1,7 @@ import ctypes +from quart import current_app as app, request + # so we don't keep repeating the same # type for all the fields _i = ctypes.c_uint8 @@ -55,3 +57,130 @@ class Permissions(ctypes.Union): def numby(self): return self.binary + + +ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111) + + +async def base_permissions(member_id, guild_id) -> Permissions: + """Compute the base permissions for a given user. + + Base permissions are + (permissions from @everyone role) + + (permissions from any other role the member has) + + This will give ALL_PERMISSIONS if base permissions + has the Administrator bit set. + """ + owner_id = await app.db.fetchval(""" + SELECT owner_id + FROM guilds + WHERE id = $1 + """, guild_id) + + if owner_id == member_id: + return ALL_PERMISSIONS + + # get permissions for @everyone + everyone_perms = await app.db.fetchval(""" + SELECT permissions + FROM roles + WHERE guild_id = $1 + """, guild_id) + + permissions = everyone_perms + + role_perms = await app.db.fetch(""" + SELECT permissions + FROM roles + WHERE guild_id = $1 AND user_id = $2 + """, guild_id, member_id) + + for perm_num in role_perms: + permissions.binary |= perm_num + + if permissions.bits.administrator: + return ALL_PERMISSIONS + + return permissions + + +def _mix(perms: Permissions, overwrite: dict) -> Permissions: + # we make a copy of the binary representation + # so we don't modify the old perms in-place + # which could be an unwanted side-effect + result = perms.binary + + # negate the permissions that are denied + result &= ~overwrite['deny'] + + # combine the permissions that are allowed + result |= overwrite['allow'] + + return Permissions(result) + + +def _overwrite_mix(perms: Permissions, overwrites: dict, + target_id: int) -> Permissions: + overwrite = overwrites.get(target_id) + + if overwrite: + # only mix if overwrite found + return _mix(perms, overwrite) + + return perms + + +async def compute_overwrites(base_perms, user_id, channel_id: int, + guild_id: int = None): + """Compute the permissions in the context of a channel.""" + + if base_perms.bits.administrator: + return ALL_PERMISSIONS + + perms = base_perms + + # list of overwrites + overwrites = await app.storage.chan_overwrites(channel_id) + + if not guild_id: + guild_id = await app.storage.guild_from_channel(channel_id) + + # make it a map for better usage + overwrites = {o['id']: o for o in overwrites} + + perms = _overwrite_mix(perms, overwrites, guild_id) + + # apply role specific overwrites + allow, deny = 0, 0 + + # fetch roles from user and convert to int + role_ids = await app.storage.get_member_role_ids(guild_id, user_id) + role_ids = map(int, role_ids) + + # make the allow and deny binaries + for role_id in role_ids: + overwrite = overwrites.get(role_id) + if overwrite: + allow |= overwrite['allow'] + deny |= overwrite['deny'] + + # final step for roles: mix + perms = _mix(perms, { + 'allow': allow, + 'deny': deny + }) + + # apply member specific overwrites + perms = _overwrite_mix(perms, overwrites, user_id) + + return perms + + +async def get_permissions(member_id, channel_id): + """Get all the permissions for a user in a channel.""" + guild_id = await app.storage.guild_from_channel(channel_id) + base_perms = await base_permissions(member_id, guild_id) + + return await compute_overwrites(base_perms, member_id, + channel_id, guild_id) diff --git a/litecord/storage.py b/litecord/storage.py index 313b4f2..8cb5f8a 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -166,7 +166,9 @@ class Storage: WHERE guild_id = $1 and user_id = $2 """, guild_id, member_id) - async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: + async def get_member_role_ids(self, guild_id: int, + member_id: int) -> List[int]: + """Get a list of role IDs that are on a member.""" roles = await self.db.fetch(""" SELECT role_id::text FROM member_roles @@ -186,6 +188,10 @@ class Storage: VALUES ($1, $2, $3) """, member_id, guild_id, guild_id) + return list(map(str, roles)) + + async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: + roles = await self.get_member_role_ids(guild_id, member_id) return { 'user': await self.get_user(member_id), 'nick': row['nickname'], @@ -309,7 +315,7 @@ class Storage: WHERE channels.id = $1 """, channel_id) - async def _chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]: + async def chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]: overwrite_rows = await self.db.fetch(""" SELECT target_type, target_role, target_user, allow, deny FROM channel_overwrites @@ -355,8 +361,8 @@ class Storage: dbase['type'] = chan_type res = await self._channels_extra(dbase) - res['permission_overwrites'] = \ - list(await self._chan_overwrites(channel_id)) + res['permission_overwrites'] = await self.chan_overwrites( + channel_id) res['id'] = str(res['id']) return res @@ -421,8 +427,8 @@ class Storage: res = await self._channels_extra(drow) - res['permission_overwrites'] = \ - list(await self._chan_overwrites(row['id'])) + res['permission_overwrites'] = await self.chan_overwrites( + row['id']) # Making sure. res['id'] = str(res['id'])