From 04d89a221401cdf1a0f4620d9f1c0d32c01b6cf3 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Mon, 5 Nov 2018 22:04:48 -0300 Subject: [PATCH] pubsub.lazy_guilds: major refactor This makes the whole process of generating a member list easier to understand and modify (from my point of view). The actual event dispatching functionality is not on this commit. - permissions: add optional storage kwarg --- litecord/permissions.py | 49 +++-- litecord/pubsub/lazy_guild.py | 337 +++++++++++++++++++++------------- 2 files changed, 237 insertions(+), 149 deletions(-) diff --git a/litecord/permissions.py b/litecord/permissions.py index f68e259..f5fa63c 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -65,7 +65,7 @@ class Permissions(ctypes.Union): ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111) -async def base_permissions(member_id, guild_id) -> Permissions: +async def base_permissions(member_id, guild_id, storage=None) -> Permissions: """Compute the base permissions for a given user. Base permissions are @@ -75,7 +75,11 @@ async def base_permissions(member_id, guild_id) -> Permissions: This will give ALL_PERMISSIONS if base permissions has the Administrator bit set. """ - owner_id = await app.db.fetchval(""" + + if not storage: + storage = app.storage + + owner_id = await storage.db.fetchval(""" SELECT owner_id FROM guilds WHERE id = $1 @@ -85,7 +89,7 @@ async def base_permissions(member_id, guild_id) -> Permissions: return ALL_PERMISSIONS # get permissions for @everyone - everyone_perms = await app.db.fetchval(""" + everyone_perms = await storage.db.fetchval(""" SELECT permissions FROM roles WHERE guild_id = $1 @@ -93,7 +97,7 @@ async def base_permissions(member_id, guild_id) -> Permissions: permissions = Permissions(everyone_perms) - role_ids = await app.db.fetch(""" + role_ids = await storage.db.fetch(""" SELECT role_id FROM member_roles WHERE guild_id = $1 AND user_id = $2 @@ -102,7 +106,7 @@ async def base_permissions(member_id, guild_id) -> Permissions: role_perms = [] for row in role_ids: - rperm = await app.db.fetchval(""" + rperm = await storage.db.fetchval(""" SELECT permissions FROM roles WHERE id = $1 @@ -119,7 +123,7 @@ async def base_permissions(member_id, guild_id) -> Permissions: return permissions -def _mix(perms: Permissions, overwrite: dict) -> Permissions: +def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions: # we make a copy of the binary representation # so we don't modify the old perms in-place # which could be an unwanted side-effect @@ -134,20 +138,22 @@ def _mix(perms: Permissions, overwrite: dict) -> Permissions: return Permissions(result) -def _overwrite_mix(perms: Permissions, overwrites: dict, - target_id: int) -> Permissions: +def overwrite_find_mix(perms: Permissions, overwrites: dict, + target_id: int) -> Permissions: overwrite = overwrites.get(target_id) if overwrite: # only mix if overwrite found - return _mix(perms, overwrite) + return overwrite_mix(perms, overwrite) return perms async def compute_overwrites(base_perms, user_id, channel_id: int, - guild_id: int = None): + guild_id: int = None, storage=None): """Compute the permissions in the context of a channel.""" + if not storage: + storage = app.storage if base_perms.bits.administrator: return ALL_PERMISSIONS @@ -155,21 +161,21 @@ async def compute_overwrites(base_perms, user_id, channel_id: int, perms = base_perms # list of overwrites - overwrites = await app.storage.chan_overwrites(channel_id) + overwrites = await storage.chan_overwrites(channel_id) if not guild_id: - guild_id = await app.storage.guild_from_channel(channel_id) + guild_id = await storage.guild_from_channel(channel_id) # make it a map for better usage overwrites = {o['id']: o for o in overwrites} - perms = _overwrite_mix(perms, overwrites, guild_id) + perms = overwrite_find_mix(perms, overwrites, guild_id) # apply role specific overwrites allow, deny = 0, 0 # fetch roles from user and convert to int - role_ids = await app.storage.get_member_role_ids(guild_id, user_id) + role_ids = await storage.get_member_role_ids(guild_id, user_id) role_ids = map(int, role_ids) # make the allow and deny binaries @@ -180,26 +186,29 @@ async def compute_overwrites(base_perms, user_id, channel_id: int, deny |= overwrite['deny'] # final step for roles: mix - perms = _mix(perms, { + perms = overwrite_mix(perms, { 'allow': allow, 'deny': deny }) # apply member specific overwrites - perms = _overwrite_mix(perms, overwrites, user_id) + perms = overwrite_find_mix(perms, overwrites, user_id) return perms -async def get_permissions(member_id, channel_id): +async def get_permissions(member_id, channel_id, *, storage=None): """Get all the permissions for a user in a channel.""" - guild_id = await app.storage.guild_from_channel(channel_id) + if not storage: + storage = app.storage + + guild_id = await storage.guild_from_channel(channel_id) # for non guild channels if not guild_id: return ALL_PERMISSIONS - base_perms = await base_permissions(member_id, guild_id) + base_perms = await base_permissions(member_id, guild_id, storage) return await compute_overwrites(base_perms, member_id, - channel_id, guild_id) + channel_id, guild_id, storage) diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index d010f78..2202234 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -2,15 +2,57 @@ Main code for Lazy Guild implementation in litecord. """ import pprint +from dataclasses import dataclass, asdict from collections import defaultdict -from typing import Any, List, Dict +from typing import Any, List, Dict, Union from logbook import Logger -from .dispatcher import Dispatcher +from litecord.pubsub.dispatcher import Dispatcher +from litecord.permissions import ( + Permissions, overwrite_find_mix, get_permissions +) log = Logger(__name__) +GroupID = Union[int, str] +Presence = Dict[str, Any] + + +@dataclass +class GroupInfo: + gid: GroupID + name: str + position: int + permissions: Permissions + + +@dataclass +class MemberList: + groups: List[GroupInfo] = None + group_info: Dict[GroupID, GroupInfo] = None + data: Dict[GroupID, Presence] = None + overwrites: Dict[int, Dict[str, Any]] = None + + def __bool__(self): + """Return if the current member list is fully initialized.""" + list_dict = asdict(self) + return all(v is not None for v in list_dict.values()) + + def __iter__(self): + """Iterate over all groups in the correct order. + + Yields a tuple containing :class:`GroupInfo` and + the List[Presence] for the group. + """ + for group in self.groups: + yield group, self.data[group.gid] + + +def _to_simple_group(presence: dict) -> str: + """Return a simple group (not a role), given a presence.""" + return 'offline' if presence['status'] == 'offline' else 'online' + class GuildMemberList: """This class stores the current member list information @@ -41,7 +83,6 @@ class GuildMemberList: """ def __init__(self, guild_id: int, channel_id: int, main_lg): - self.main_lg = main_lg self.guild_id = guild_id self.channel_id = channel_id @@ -52,167 +93,205 @@ class GuildMemberList: self.presence = main.app.presence self.state_man = main.app.state_manager - self.member_list = None - self.items = None + self.list = MemberList(None, None, None, None) #: holds the state of subscribed shards # to this channels' member list self.state = set() + def _set_empty_list(self): + self.list = MemberList(None, None, None, None) + async def _init_check(self): """Check if the member list is initialized before messing with it.""" - if self.member_list is None: + if not self.list: await self._init_member_list() - async def get_roles(self) -> List[Dict[str, Any]]: + async def _fetch_overwrites(self): + overwrites = await self.storage.chan_overwrites(self.channel_id) + overwrites = {int(ov['id']): ov for ov in overwrites} + self.list.overwrites = overwrites + + def _calc_member_group(self, roles: List[int], status: str): + """Calculate the best fitting group for a member, + given their roles and their current status.""" + try: + # the first group in the list + # that the member is entitled to is + # the selected group for the member. + group_id = next(g.gid for g in self.list.groups + if g.gid in roles) + except StopIteration: + # no group was found, so we fallback + # to simple group" + group_id = _to_simple_group({'status': status}) + + return group_id + + async def get_roles(self) -> List[GroupInfo]: """Get role information, but only: - the ID - the name - the position - - of all HOISTED roles.""" - # TODO: write own query for this - # TODO: calculate channel overrides - roles = await self.storage.get_role_data(self.guild_id) + - the permissions - return [{ - 'id': role['id'], - 'name': role['name'], - 'position': role['position'] - } for role in roles if role['hoist']] + of all HOISTED roles AND roles that + have permissions to read the channel + being referred to this :class:`GuildMemberList` + instance. + + The list is sorted by position. + """ + roledata = await self.storage.db.fetch(""" + SELECT id, name, hoist, position, permissions + FROM roles + WHERE guild_id = $1 + """, self.guild_id) + + hoisted = [ + GroupInfo(row['id'], row['name'], + row['position'], row['permissions']) + for row in roledata if row['hoist'] + ] + + # sort role list by position + hoisted = sorted(hoisted, key=lambda group: group.position) + + # we need to store them for later on + # for members + await self._fetch_overwrites() + + def _can_read_chan(group: GroupInfo): + # get the base role perms + role_perms = group.permissions + + # then the final perms for that role if + # any overwrite exists in the channel + final_perms = overwrite_find_mix( + role_perms, self.list.overwrites, group.gid) + + # update the group's permissions + # with the mixed ones + group.permissions = final_perms + + # if the role can read messages, then its + # part of the group. + return final_perms.bits.read_messages + + return list(filter(_can_read_chan, hoisted)) + + async def set_groups(self): + """Get the groups for the member list.""" + role_groups = await self.get_roles() + role_ids = [g.gid for g in role_groups] + + self.list.groups = role_ids + ['online', 'offline'] + + # inject default groups 'online' and 'offline' + self.list.groups = role_ids + [ + GroupInfo('online', 'online', -1, -1), + GroupInfo('offline', 'offline', -1, -1) + ] + self.list.group_info = {g.gid: g for g in role_groups} + + async def _pass_1(self, guild_presences: List[Presence]): + """First pass on generating the member list. + + This assigns all presences a single group. + """ + for presence in guild_presences: + member_id = int(presence['user']['id']) + + # list of roles for the member + member_roles = list(map(int, presence['roles'])) + + # get the member's permissions relative to the channel + # (accounting for channel overwrites) + member_perms = await get_permissions( + member_id, self.channel_id, storage=self.storage) + + if not member_perms.bits.read_messages: + continue + + # if the member is offline, we + # default give them the offline group. + status = presence['status'] + group_id = ('offline' if status == 'offline' + else self._calc_member_group(member_roles, status)) + + self.list.data[group_id].append(presence) + + async def _sort_groups(self): + members = await self.storage.get_member_data(self.guild_id) + + # make a dictionary of member ids to nicknames + # so we don't need to keep querying the db on + # every loop iteration + member_nicks = {m['user']['id']: m.get('nick') + for m in members} + + for group_members in self.list.data.values(): + def display_name(presence: Presence) -> str: + uid = presence['user']['id'] + + uname = presence['user']['username'] + nick = member_nicks.get(uid) + + return nick or uname + + # this should update the list in-place + group_members.sort(key=display_name) async def _init_member_list(self): - """Fill in :attr:`GuildMemberList.member_list` - with information about the guilds' members.""" + """Generate the main member list with groups.""" member_ids = await self.storage.get_member_ids(self.guild_id) guild_presences = await self.presence.guild_presences( member_ids, self.guild_id) - guild_roles = await self.get_roles() - - # sort by position - guild_roles.sort(key=lambda role: role['position']) - roleids = [r['id'] for r in guild_roles] - - # groups are: - # - roles that are hoisted - # - "online" and "offline", with "online" - # being for people without any roles. - - groups = roleids + ['online', 'offline'] + await self.set_groups() log.debug('{} presences, {} groups', - len(guild_presences), len(groups)) + len(guild_presences), + len(self.list.groups)) - group_data = {group: [] for group in groups} + self.list.data = {group.gid: [] for group in self.list.groups} - print('group data', group_data) + # first pass: set which presences + # go to which groups + await self._pass_1(guild_presences) - def _try_hier(role_id: str, roleids: list): - """Try to fetch a role's position in the hierarchy""" - try: - return roleids.index(role_id) - except ValueError: - # the given role isn't on a group - # so it doesn't count for our purposes. - return 0 + # second pass: sort each group's members + # by the display name + await self._sort_groups() - for presence in guild_presences: - # simple group (online or offline) - # we'll decide on the best group for the presence later on - simple_group = ('offline' - if presence['status'] == 'offline' - else 'online') + @property + def items(self) -> list: + """Main items list.""" - # get the best possible role - roles = sorted( - presence['roles'], - key=lambda role_id: _try_hier(role_id, roleids) - ) + # TODO: maybe make this stored in the list + # so we don't need to keep regenning? - try: - best_role_id = roles[0] - except IndexError: - # no hoisted roles exist in the guild, assign - # the @everyone role - best_role_id = str(self.guild_id) - - print('best role', best_role_id, str(self.guild_id)) - print('simple group assign', simple_group) - - # if the best role is literally the @everyone role, - # this user has no hoisted roles - if best_role_id == str(self.guild_id): - # this user has no roles, put it on online/offline - group_data[simple_group].append(presence) - continue - - # this user has a best_role that isn't the - # @everyone role, so we'll put them in the respective group - try: - group_data[best_role_id].append(presence) - except KeyError: - group_data[simple_group].append(presence) - - # go through each group and sort the resulting members by display name - - members = await self.storage.get_member_data(self.guild_id) - member_nicks = {member['user']['id']: member.get('nick') - for member in members} - - # now we'll sort each group by their display name - # (can be their current nickname OR their username - # if no nickname is set) - print('pre-sorted group data') - pprint.pprint(group_data) - - for _, group_list in group_data.items(): - def display_name(presence: dict) -> str: - uid = presence['user']['id'] - - uname = presence['user']['username'] - nick = member_nicks[uid] - - return nick or uname - - group_list.sort(key=display_name) - - pprint.pprint(group_data) - - self.member_list = { - 'groups': groups, - 'data': group_data - } - - def get_items(self) -> list: - """Generate the main items list,""" - if self.member_list is None: + if not self.list: return [] - if self.items: - return self.items - - groups = self.member_list['groups'] - res = [] - for group in groups: - members = self.member_list['data'][group] + # NOTE: maybe use map()? + for group, presences in self.list: res.append({ 'group': { - 'id': group, - 'count': len(members), + 'id': group.gid, + 'count': len(presences), } }) - for member in members: + for presence in presences: res.append({ - 'member': member + 'member': presence }) - self.items = res return res async def sub(self, session_id: str): @@ -230,7 +309,7 @@ class GuildMemberList: # uninitialized) for a future subscriber. if not self.state: - self.member_list = None + self._set_empty_list() async def shard_query(self, session_id: str, ranges: list): """Send a GUILD_MEMBER_LIST_UPDATE event @@ -259,7 +338,7 @@ class GuildMemberList: # subscribe the state to this list await self.sub(session_id) - # TODO: subscribe shard to 'everyone' + # TODO: subscribe shard to the 'everyone' member list # and forward the query to that list reply = { @@ -273,9 +352,9 @@ class GuildMemberList: 'groups': [ { - 'count': len(self.member_list['data'][group]), - 'id': group - } for group in self.member_list['groups'] + 'count': len(presences), + 'id': group.gid + } for group, presences in self.list ], 'ops': [], @@ -288,12 +367,12 @@ class GuildMemberList: if itemcount < 0: continue - items = self.get_items() + # TODO: subscribe user to the slice reply['ops'].append({ 'op': 'SYNC', 'range': [start, end], - 'items': items[start:end], + 'items': self.items[start:end], }) # the first GUILD_MEMBER_LIST_UPDATE for a shard