permissions: add basic permission api

- litecord.auth: insert request.user_id
 - storage: add get_member_role_ids
This commit is contained in:
Luna Mendes 2018-11-04 17:54:48 -03:00
parent 818571d336
commit 87dd70b4d9
3 changed files with 150 additions and 7 deletions

View File

@ -58,6 +58,12 @@ async def raw_token_check(token, db=None):
async def token_check(): async def token_check():
"""Check token information.""" """Check token information."""
# first, check if the request info already has a uid
try:
return request.user_id
except AttributeError:
pass
try: try:
token = request.headers['Authorization'] token = request.headers['Authorization']
except KeyError: except KeyError:
@ -66,4 +72,6 @@ async def token_check():
if token.startswith('Bot '): if token.startswith('Bot '):
token = token.replace('Bot ', '') token = token.replace('Bot ', '')
return await raw_token_check(token) user_id = await raw_token_check(token)
request.user_id = user_id
return user_id

View File

@ -1,5 +1,7 @@
import ctypes import ctypes
from quart import current_app as app, request
# so we don't keep repeating the same # so we don't keep repeating the same
# type for all the fields # type for all the fields
_i = ctypes.c_uint8 _i = ctypes.c_uint8
@ -55,3 +57,130 @@ class Permissions(ctypes.Union):
def numby(self): def numby(self):
return self.binary return self.binary
ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111)
async def base_permissions(member_id, guild_id) -> Permissions:
"""Compute the base permissions for a given user.
Base permissions are
(permissions from @everyone role) +
(permissions from any other role the member has)
This will give ALL_PERMISSIONS if base permissions
has the Administrator bit set.
"""
owner_id = await app.db.fetchval("""
SELECT owner_id
FROM guilds
WHERE id = $1
""", guild_id)
if owner_id == member_id:
return ALL_PERMISSIONS
# get permissions for @everyone
everyone_perms = await app.db.fetchval("""
SELECT permissions
FROM roles
WHERE guild_id = $1
""", guild_id)
permissions = everyone_perms
role_perms = await app.db.fetch("""
SELECT permissions
FROM roles
WHERE guild_id = $1 AND user_id = $2
""", guild_id, member_id)
for perm_num in role_perms:
permissions.binary |= perm_num
if permissions.bits.administrator:
return ALL_PERMISSIONS
return permissions
def _mix(perms: Permissions, overwrite: dict) -> Permissions:
# we make a copy of the binary representation
# so we don't modify the old perms in-place
# which could be an unwanted side-effect
result = perms.binary
# negate the permissions that are denied
result &= ~overwrite['deny']
# combine the permissions that are allowed
result |= overwrite['allow']
return Permissions(result)
def _overwrite_mix(perms: Permissions, overwrites: dict,
target_id: int) -> Permissions:
overwrite = overwrites.get(target_id)
if overwrite:
# only mix if overwrite found
return _mix(perms, overwrite)
return perms
async def compute_overwrites(base_perms, user_id, channel_id: int,
guild_id: int = None):
"""Compute the permissions in the context of a channel."""
if base_perms.bits.administrator:
return ALL_PERMISSIONS
perms = base_perms
# list of overwrites
overwrites = await app.storage.chan_overwrites(channel_id)
if not guild_id:
guild_id = await app.storage.guild_from_channel(channel_id)
# make it a map for better usage
overwrites = {o['id']: o for o in overwrites}
perms = _overwrite_mix(perms, overwrites, guild_id)
# apply role specific overwrites
allow, deny = 0, 0
# fetch roles from user and convert to int
role_ids = await app.storage.get_member_role_ids(guild_id, user_id)
role_ids = map(int, role_ids)
# make the allow and deny binaries
for role_id in role_ids:
overwrite = overwrites.get(role_id)
if overwrite:
allow |= overwrite['allow']
deny |= overwrite['deny']
# final step for roles: mix
perms = _mix(perms, {
'allow': allow,
'deny': deny
})
# apply member specific overwrites
perms = _overwrite_mix(perms, overwrites, user_id)
return perms
async def get_permissions(member_id, channel_id):
"""Get all the permissions for a user in a channel."""
guild_id = await app.storage.guild_from_channel(channel_id)
base_perms = await base_permissions(member_id, guild_id)
return await compute_overwrites(base_perms, member_id,
channel_id, guild_id)

View File

@ -166,7 +166,9 @@ class Storage:
WHERE guild_id = $1 and user_id = $2 WHERE guild_id = $1 and user_id = $2
""", guild_id, member_id) """, guild_id, member_id)
async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]: async def get_member_role_ids(self, guild_id: int,
member_id: int) -> List[int]:
"""Get a list of role IDs that are on a member."""
roles = await self.db.fetch(""" roles = await self.db.fetch("""
SELECT role_id::text SELECT role_id::text
FROM member_roles FROM member_roles
@ -186,6 +188,10 @@ class Storage:
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
""", member_id, guild_id, guild_id) """, member_id, guild_id, guild_id)
return list(map(str, roles))
async def _member_dict(self, row, guild_id, member_id) -> Dict[str, Any]:
roles = await self.get_member_role_ids(guild_id, member_id)
return { return {
'user': await self.get_user(member_id), 'user': await self.get_user(member_id),
'nick': row['nickname'], 'nick': row['nickname'],
@ -309,7 +315,7 @@ class Storage:
WHERE channels.id = $1 WHERE channels.id = $1
""", channel_id) """, channel_id)
async def _chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]: async def chan_overwrites(self, channel_id: int) -> List[Dict[str, Any]]:
overwrite_rows = await self.db.fetch(""" overwrite_rows = await self.db.fetch("""
SELECT target_type, target_role, target_user, allow, deny SELECT target_type, target_role, target_user, allow, deny
FROM channel_overwrites FROM channel_overwrites
@ -355,8 +361,8 @@ class Storage:
dbase['type'] = chan_type dbase['type'] = chan_type
res = await self._channels_extra(dbase) res = await self._channels_extra(dbase)
res['permission_overwrites'] = \ res['permission_overwrites'] = await self.chan_overwrites(
list(await self._chan_overwrites(channel_id)) channel_id)
res['id'] = str(res['id']) res['id'] = str(res['id'])
return res return res
@ -421,8 +427,8 @@ class Storage:
res = await self._channels_extra(drow) res = await self._channels_extra(drow)
res['permission_overwrites'] = \ res['permission_overwrites'] = await self.chan_overwrites(
list(await self._chan_overwrites(row['id'])) row['id'])
# Making sure. # Making sure.
res['id'] = str(res['id']) res['id'] = str(res['id'])