diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index dae0270..3e9b7b3 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -18,12 +18,18 @@ along with this program. If not, see . """ """ -Main code for Lazy Guild implementation in litecord. +lazy guilds: + + the lazy guild api docs (which are heavily based off this implementation) + can be found on + + https://luna.gitlab.io/discord-unofficial-docs/lazy_guilds.html """ + import asyncio -from dataclasses import dataclass, asdict, field from collections import defaultdict from typing import Any, List, Dict, Union +from dataclasses import dataclass, asdict, field from logbook import Logger @@ -85,11 +91,14 @@ class MemberList: def __bool__(self): """Return if the current member list is fully initialized.""" + # asdict comes from dataclasses list_dict = asdict(self) # ignore the bool status of overwrites - return all(bool(list_dict[k]) - for k in ('groups', 'data', 'presences', 'members')) + return all( + bool(list_dict[k]) + for k in ('groups', 'data', 'presences', 'members') + ) def __iter__(self): """Iterate over all groups in the correct order. @@ -104,8 +113,12 @@ class MemberList: yield group, self.data[group.gid] @property - def iter_non_empty(self): - """Only iterate through non-empty groups""" + def iter_non_empty(self) -> tuple: + """Only iterate through non-empty groups. + + Note that while the offline group can be empty, it is always + yielded out, to comply with Discord. + """ for group, member_ids in self: count = len(member_ids) @@ -183,6 +196,13 @@ def _to_simple_group(presence: dict) -> str: async def everyone_allow(gml) -> bool: + """Return if the '@everyone' role can access a given member list. + + This is important in regards to member list IDs, since if the '@everyone' + role can access the list, then the list is downgraded to an 'everyone' list. + + If the role can't access the list, then the list keeps its list ID. + """ everyone_perms = await role_permissions( gml.guild_id, gml.guild_id, @@ -190,20 +210,7 @@ async def everyone_allow(gml) -> bool: storage=gml.storage ) - return everyone_perms.bits.read_messages - - -def display_name(member_nicks: Dict[str, str], presence: Presence) -> str: - """Return the display name of a presence. - - Used to sort groups. - """ - uid = presence['user']['id'] - - uname = presence['user']['username'] - nick = member_nicks.get(uid) - - return nick or uname + return bool(everyone_perms.bits.read_messages) def merge(member: dict, presence: Presence) -> dict: @@ -344,7 +351,8 @@ class GuildMemberList: return group_id - def _can_read_chan(self, group: GroupInfo): + def _can_read_chan(self, group: GroupInfo) -> bool: + """Return if a given group can acess the channel""" # get the base role perms role_perms = group.permissions @@ -359,9 +367,9 @@ class GuildMemberList: # if the role can read messages, then its # part of the group. - return final_perms.bits.read_messages + return bool(final_perms.bits.read_messages) - async def get_role_groups(self) -> List[GroupInfo]: + async def _get_role_groups(self) -> List[GroupInfo]: """Get role information, but only: - the ID - the name @@ -382,15 +390,19 @@ class GuildMemberList: """, self.guild_id) hoisted = [ - GroupInfo(row['id'], row['name'], - row['position'], - Permissions(row['permissions'])) + GroupInfo( + row['id'], row['name'], + row['position'], + Permissions(row['permissions']) + ) for row in roledata if row['hoist'] ] # sort role list by position - hoisted = sorted(hoisted, key=lambda group: group.position, - reverse=True) + hoisted = sorted( + hoisted, key=lambda group: group.position, + reverse=True + ) # we need to store the overwrites since # we have incoming presences to manage. @@ -398,9 +410,9 @@ class GuildMemberList: return list(filter(self._can_read_chan, hoisted)) - async def set_groups(self): + async def _set_groups(self): """Get the groups for the member list.""" - role_groups = await self.get_role_groups() + role_groups = await self._get_role_groups() # inject default groups 'online' and 'offline' # their position is always going to be the last ones. @@ -409,9 +421,9 @@ class GuildMemberList: GroupInfo('offline', 'offline', MAX_ROLES + 2, 0) ] - async def get_group_for_member(self, member_id: int, - roles: List[Union[str, int]], - status: str) -> GroupID: + async def _get_group_for_member(self, member_id: int, + roles: List[Union[str, int]], + status: str) -> GroupID: """Return a fitting group ID for the member.""" member_roles = list(map(int, roles)) @@ -435,7 +447,7 @@ class GuildMemberList: for member_id in member_ids: presence = self.list.presences[member_id] - group_id = await self.get_group_for_member( + group_id = await self._get_group_for_member( member_id, presence['roles'], presence['status'] ) @@ -451,19 +463,12 @@ class GuildMemberList: self.list.members[member_id] = member self.list.data[group_id].append(member_id) - async def get_member_nicks_dict(self) -> dict: - """Get a dictionary with nickname information.""" - members = await self.storage.get_member_data(self.guild_id) + def _display_name(self, member_id: int) -> str: + """Get the display name for a given member. - # make a dictionary of member ids to nicknames - # so we don't need to keep querying the db on - # every loop iteration - member_nicks = {m['user']['id']: m.get('nick') - for m in members} - - return member_nicks - - def display_name(self, member_id: int): + This is more efficient than the old function (not method) of same + name, as we dont need to pass nickname information to it. + """ member = self.list.members.get(member_id) if not member_id: @@ -476,10 +481,9 @@ class GuildMemberList: async def _sort_groups(self): for member_ids in self.list.data.values(): - # this should update the list in-place member_ids.sort( - key=self.display_name) + key=self._display_name) async def __init_member_list(self): """Generate the main member list with groups.""" @@ -492,7 +496,7 @@ class GuildMemberList: self.list.presences = {int(p['user']['id']): p for p in presences} - await self.set_groups() + await self._set_groups() log.debug('init: {} members, {} groups', len(member_ids), @@ -514,7 +518,7 @@ class GuildMemberList: finally: self._list_lock.release() - def get_member_as_item(self, member_id: int) -> dict: + def _get_member_as_item(self, member_id: int) -> dict: """Get an item representing a member.""" member = self.list.members[member_id] presence = self.list.presences[member_id] @@ -550,17 +554,17 @@ class GuildMemberList: for member_id in member_ids: res.append({ - 'member': self.get_member_as_item(member_id) + 'member': self._get_member_as_item(member_id) }) return res - async def sub(self, _session_id: str): - """Subscribe a shard to the member list.""" - await self._init_check() - def unsub(self, session_id: str): - """Unsubscribe a shard from the member list""" + """Unsubscribe a shard from the member list + + Subscription for the member list is handled via the + :meth:`GuildMemberList.shard_query` method. + """ try: self.state.pop(session_id) except KeyError: @@ -574,8 +578,11 @@ class GuildMemberList: if not self.state: self._set_empty_list() - def get_state(self, session_id: str): - """Get the state for a session id.""" + def _get_state(self, session_id: str): + """Get the state for a session id. + + Wrapper for :meth:`StateManager.fetch_raw` + """ try: state = self.state_man.fetch_raw(session_id) return state @@ -605,7 +612,7 @@ class GuildMemberList: ] } - states = map(self.get_state, session_ids) + states = map(self._get_state, session_ids) states = filter(lambda state: state is not None, states) dispatched = [] @@ -618,7 +625,7 @@ class GuildMemberList: return dispatched - async def resync(self, session_ids: int, item_index: int) -> List[str]: + async def _resync(self, session_ids: int, item_index: int) -> List[str]: """Send a SYNC event to all states that are subscribed to an item. Returns @@ -653,13 +660,13 @@ class GuildMemberList: return result - async def resync_by_item(self, item_index: int): + async def _resync_by_item(self, item_index: int): """Resync but only giving the item index.""" if item_index is None: return [] - return await self.resync( - self.get_subs(item_index), + return await self._resync( + self._get_subs(item_index), item_index ) @@ -722,8 +729,9 @@ class GuildMemberList: # send SYNCs to the state that requested await self._dispatch_sess([session_id], ops) - def get_item_index(self, user_id: Union[str, int]) -> int: + def _get_item_index(self, user_id: Union[str, int]) -> int: """Get the item index a user is on.""" + # NOTE: this is inefficient user_id = int(user_id) index = 1 @@ -741,7 +749,7 @@ class GuildMemberList: return None - def get_group_item_index(self, group_id: GroupID) -> int: + def _get_group_item_index(self, group_id: GroupID) -> int: """Get the item index a group is on.""" index = 0 @@ -753,7 +761,7 @@ class GuildMemberList: return None - def state_is_subbed(self, item_index, session_id: str) -> bool: + def _is_subbed(self, item_index, session_id: str) -> bool: """Return if a state's ranges include the given item index.""" @@ -765,15 +773,21 @@ class GuildMemberList: return False - def get_subs(self, item_index: int) -> filter: + def _get_subs(self, item_index: int) -> filter: """Get the list of subscribed states to a given item.""" return filter( - lambda sess_id: self.state_is_subbed(item_index, sess_id), + lambda sess_id: self._is_subbed(item_index, sess_id), self.state.keys() ) async def _pres_update_simple(self, user_id: int): - item_index = self.get_item_index(user_id) + """Handler for simple presence updates. + + Simple presence updates are just a single UPDATE operator for + the client. usually called when a user still maintains their role + list but changes to from online to idle/dnd and vice-versa. + """ + item_index = self._get_item_index(user_id) if item_index is None: log.warning('lazy guild got invalid pres update uid={}', @@ -781,7 +795,7 @@ class GuildMemberList: return [] item = self.items[item_index] - session_ids = self.get_subs(item_index) + session_ids = self._get_subs(item_index) # simple update means we just give an UPDATE # operation @@ -818,8 +832,8 @@ class GuildMemberList: ops = [] - old_user_index = self.get_item_index(user_id) - old_group_index = self.get_group_item_index(old_group) + old_user_index = self._get_item_index(user_id) + old_group_index = self._get_group_item_index(old_group) ops.append(Operation('DELETE', { 'index': old_user_index @@ -831,7 +845,7 @@ class GuildMemberList: await self._sort_groups() - new_user_index = self.get_item_index(user_id) + new_user_index = self._get_item_index(user_id) ops.append(Operation('INSERT', { 'index': new_user_index, @@ -845,7 +859,7 @@ class GuildMemberList: # the first member in the group. if self.list.is_birth(new_group) and new_group != 'offline': ops.append(Operation('INSERT', { - 'index': self.get_group_item_index(new_group), + 'index': self._get_group_item_index(new_group), 'item': { 'group': str(new_group), 'count': 1 } @@ -858,8 +872,8 @@ class GuildMemberList: 'index': old_group_index, })) - session_ids_old = list(self.get_subs(old_user_index)) - session_ids_new = list(self.get_subs(new_user_index)) + session_ids_old = list(self._get_subs(old_user_index)) + session_ids_new = list(self._get_subs(new_user_index)) # session_ids = set(session_ids_old + session_ids_new) # NOTE: this section is what a realistic implementation @@ -876,8 +890,8 @@ class GuildMemberList: # ) # merge both results together - return (await self.resync(session_ids_old, old_user_index) + - await self.resync(session_ids_new, new_user_index)) + return (await self._resync(session_ids_old, old_user_index) + + await self._resync(session_ids_new, new_user_index)) async def new_member(self, user_id: int): """Insert a new member.""" @@ -906,7 +920,7 @@ class GuildMemberList: self.list.members[user_id] = member # find a group for the newcomer - group_id = await self.get_group_for_member( + group_id = await self._get_group_for_member( user_id, member['roles'], pres['status']) if group_id is None: @@ -917,13 +931,13 @@ class GuildMemberList: self.list.data[group_id].append(user_id) await self._sort_groups() - user_index = self.get_item_index(user_id) + user_index = self._get_item_index(user_id) if not user_index: log.warning('lazy: new uid {} was not assigned idx', user_id) - return await self.resync_by_item(user_index) + return await self._resync_by_item(user_index) async def remove_member(self, user_id: int): """Remove a member from the list.""" @@ -933,17 +947,13 @@ class GuildMemberList: return # we need the old index to resync later on - old_idx = self.get_item_index(user_id) - - def is_valid_state(session_id): - state = self.get_state(session_id) - return state.user_id != user_id + old_idx = self._get_item_index(user_id) # for now, remove any of the users' subscribed states state_keys = self.state.keys() for session_id in state_keys: - state = self.get_state(session_id) + state = self._get_state(session_id) # if unknown state, remove from the subscriber list if state is None: @@ -978,7 +988,7 @@ class GuildMemberList: log.warning('lazy: unknown member uid {}', user_id) return - group_id = await self.get_group_for_member( + group_id = await self._get_group_for_member( user_id, member['roles'], pres['status']) if not group_id: @@ -992,7 +1002,7 @@ class GuildMemberList: return # tell everyone about the removal. - await self.resync_by_item(old_idx) + await self._resync_by_item(old_idx) async def update_user(self, user_id: int): """Called for user updates such as avatar or username.""" @@ -1009,8 +1019,8 @@ class GuildMemberList: await self.storage.get_user(user_id) # redispatch - user_idx = self.get_item_index(user_id) - return await self.resync_by_item(user_idx) + user_idx = self._get_item_index(user_id) + return await self._resync_by_item(user_idx) async def pres_update(self, user_id: int, partial_presence: Presence): @@ -1075,7 +1085,7 @@ class GuildMemberList: # calculate a possible new group # TODO: handle when new_group is None (member loses perms) - new_group = await self.get_group_for_member( + new_group = await self._get_group_for_member( user_id, roles, status) log.debug('pres update: gid={} cid={} old_g={} new_g={}', @@ -1171,13 +1181,13 @@ class GuildMemberList: """ role_id = int(role['id']) - old_index = self.get_group_item_index(role_id) + old_index = self._get_group_item_index(role_id) if not old_index: log.warning('lazy role_pos_update: unknown group {}', role_id) return - old_sessions = list(self.get_subs(old_index)) + old_sessions = list(self._get_subs(old_index)) groups_idx = self._get_role_as_group_idx(role_id) if groups_idx is None: @@ -1202,10 +1212,10 @@ class GuildMemberList: [g.gid for g in new_groups]) self.list.groups = new_groups - new_index = self.get_group_item_index(role_id) + new_index = self._get_group_item_index(role_id) - return (await self.resync(old_sessions, old_index) + - await self.resync_by_item(new_index)) + return (await self._resync(old_sessions, old_index) + + await self._resync_by_item(new_index)) async def role_update(self, role: dict): """Update a role. @@ -1267,7 +1277,7 @@ class GuildMemberList: # states we'll resend the list info to. # find the item id for the group info - role_item_index = self.get_group_item_index(role_id) + role_item_index = self._get_group_item_index(role_id) # we only resync when we actually have an item to resync # we don't have items to resync when we: @@ -1279,7 +1289,7 @@ class GuildMemberList: # using a filter object would cause problems # as we only resync AFTER we delete the group - sess_ids_resync = (list(self.get_subs(role_item_index)) + sess_ids_resync = (list(self._get_subs(role_item_index)) if role_item_index is not None else []) @@ -1328,7 +1338,7 @@ class GuildMemberList: log.debug('there are {} session ids to resync (for item {})', len(sess_ids_resync), role_item_index) - return await self.resync(sess_ids_resync, role_item_index) + return await self._resync(sess_ids_resync, role_item_index) async def chan_update(self): """Called then a channel's data has been updated.""" @@ -1359,7 +1369,7 @@ class GuildMemberList: self.guild_id = None self.channel_id = None self.main = None - self.list = MemberList + self._set_empty_list() self.state = {} @@ -1412,6 +1422,7 @@ class LazyGuildDispatcher(Dispatcher): )) async def unsub(self, chan_id, session_id): + """Unsubscribe a session from the list.""" gml = await self.get_gml(chan_id) gml.unsub(session_id)