mirror of https://gitlab.com/litecord/litecord.git
pubsub.lazy_guild: sanity of mind
- add more comments, always good - remove some unused methods and move some others to private methods - use _set_empty_list in close() method
This commit is contained in:
parent
a69a2423de
commit
f1127d1970
|
|
@ -18,12 +18,18 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
"""
|
||||
Main code for Lazy Guild implementation in litecord.
|
||||
lazy guilds:
|
||||
|
||||
the lazy guild api docs (which are heavily based off this implementation)
|
||||
can be found on
|
||||
|
||||
https://luna.gitlab.io/discord-unofficial-docs/lazy_guilds.html
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from collections import defaultdict
|
||||
from typing import Any, List, Dict, Union
|
||||
from dataclasses import dataclass, asdict, field
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
|
|
@ -85,11 +91,14 @@ class MemberList:
|
|||
|
||||
def __bool__(self):
|
||||
"""Return if the current member list is fully initialized."""
|
||||
# asdict comes from dataclasses
|
||||
list_dict = asdict(self)
|
||||
|
||||
# ignore the bool status of overwrites
|
||||
return all(bool(list_dict[k])
|
||||
for k in ('groups', 'data', 'presences', 'members'))
|
||||
return all(
|
||||
bool(list_dict[k])
|
||||
for k in ('groups', 'data', 'presences', 'members')
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over all groups in the correct order.
|
||||
|
|
@ -104,8 +113,12 @@ class MemberList:
|
|||
yield group, self.data[group.gid]
|
||||
|
||||
@property
|
||||
def iter_non_empty(self):
|
||||
"""Only iterate through non-empty groups"""
|
||||
def iter_non_empty(self) -> tuple:
|
||||
"""Only iterate through non-empty groups.
|
||||
|
||||
Note that while the offline group can be empty, it is always
|
||||
yielded out, to comply with Discord.
|
||||
"""
|
||||
|
||||
for group, member_ids in self:
|
||||
count = len(member_ids)
|
||||
|
|
@ -183,6 +196,13 @@ def _to_simple_group(presence: dict) -> str:
|
|||
|
||||
|
||||
async def everyone_allow(gml) -> bool:
|
||||
"""Return if the '@everyone' role can access a given member list.
|
||||
|
||||
This is important in regards to member list IDs, since if the '@everyone'
|
||||
role can access the list, then the list is downgraded to an 'everyone' list.
|
||||
|
||||
If the role can't access the list, then the list keeps its list ID.
|
||||
"""
|
||||
everyone_perms = await role_permissions(
|
||||
gml.guild_id,
|
||||
gml.guild_id,
|
||||
|
|
@ -190,20 +210,7 @@ async def everyone_allow(gml) -> bool:
|
|||
storage=gml.storage
|
||||
)
|
||||
|
||||
return everyone_perms.bits.read_messages
|
||||
|
||||
|
||||
def display_name(member_nicks: Dict[str, str], presence: Presence) -> str:
|
||||
"""Return the display name of a presence.
|
||||
|
||||
Used to sort groups.
|
||||
"""
|
||||
uid = presence['user']['id']
|
||||
|
||||
uname = presence['user']['username']
|
||||
nick = member_nicks.get(uid)
|
||||
|
||||
return nick or uname
|
||||
return bool(everyone_perms.bits.read_messages)
|
||||
|
||||
|
||||
def merge(member: dict, presence: Presence) -> dict:
|
||||
|
|
@ -344,7 +351,8 @@ class GuildMemberList:
|
|||
|
||||
return group_id
|
||||
|
||||
def _can_read_chan(self, group: GroupInfo):
|
||||
def _can_read_chan(self, group: GroupInfo) -> bool:
|
||||
"""Return if a given group can acess the channel"""
|
||||
# get the base role perms
|
||||
role_perms = group.permissions
|
||||
|
||||
|
|
@ -359,9 +367,9 @@ class GuildMemberList:
|
|||
|
||||
# if the role can read messages, then its
|
||||
# part of the group.
|
||||
return final_perms.bits.read_messages
|
||||
return bool(final_perms.bits.read_messages)
|
||||
|
||||
async def get_role_groups(self) -> List[GroupInfo]:
|
||||
async def _get_role_groups(self) -> List[GroupInfo]:
|
||||
"""Get role information, but only:
|
||||
- the ID
|
||||
- the name
|
||||
|
|
@ -382,15 +390,19 @@ class GuildMemberList:
|
|||
""", self.guild_id)
|
||||
|
||||
hoisted = [
|
||||
GroupInfo(row['id'], row['name'],
|
||||
GroupInfo(
|
||||
row['id'], row['name'],
|
||||
row['position'],
|
||||
Permissions(row['permissions']))
|
||||
Permissions(row['permissions'])
|
||||
)
|
||||
for row in roledata if row['hoist']
|
||||
]
|
||||
|
||||
# sort role list by position
|
||||
hoisted = sorted(hoisted, key=lambda group: group.position,
|
||||
reverse=True)
|
||||
hoisted = sorted(
|
||||
hoisted, key=lambda group: group.position,
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# we need to store the overwrites since
|
||||
# we have incoming presences to manage.
|
||||
|
|
@ -398,9 +410,9 @@ class GuildMemberList:
|
|||
|
||||
return list(filter(self._can_read_chan, hoisted))
|
||||
|
||||
async def set_groups(self):
|
||||
async def _set_groups(self):
|
||||
"""Get the groups for the member list."""
|
||||
role_groups = await self.get_role_groups()
|
||||
role_groups = await self._get_role_groups()
|
||||
|
||||
# inject default groups 'online' and 'offline'
|
||||
# their position is always going to be the last ones.
|
||||
|
|
@ -409,7 +421,7 @@ class GuildMemberList:
|
|||
GroupInfo('offline', 'offline', MAX_ROLES + 2, 0)
|
||||
]
|
||||
|
||||
async def get_group_for_member(self, member_id: int,
|
||||
async def _get_group_for_member(self, member_id: int,
|
||||
roles: List[Union[str, int]],
|
||||
status: str) -> GroupID:
|
||||
"""Return a fitting group ID for the member."""
|
||||
|
|
@ -435,7 +447,7 @@ class GuildMemberList:
|
|||
for member_id in member_ids:
|
||||
presence = self.list.presences[member_id]
|
||||
|
||||
group_id = await self.get_group_for_member(
|
||||
group_id = await self._get_group_for_member(
|
||||
member_id, presence['roles'], presence['status']
|
||||
)
|
||||
|
||||
|
|
@ -451,19 +463,12 @@ class GuildMemberList:
|
|||
self.list.members[member_id] = member
|
||||
self.list.data[group_id].append(member_id)
|
||||
|
||||
async def get_member_nicks_dict(self) -> dict:
|
||||
"""Get a dictionary with nickname information."""
|
||||
members = await self.storage.get_member_data(self.guild_id)
|
||||
def _display_name(self, member_id: int) -> str:
|
||||
"""Get the display name for a given member.
|
||||
|
||||
# make a dictionary of member ids to nicknames
|
||||
# so we don't need to keep querying the db on
|
||||
# every loop iteration
|
||||
member_nicks = {m['user']['id']: m.get('nick')
|
||||
for m in members}
|
||||
|
||||
return member_nicks
|
||||
|
||||
def display_name(self, member_id: int):
|
||||
This is more efficient than the old function (not method) of same
|
||||
name, as we dont need to pass nickname information to it.
|
||||
"""
|
||||
member = self.list.members.get(member_id)
|
||||
|
||||
if not member_id:
|
||||
|
|
@ -476,10 +481,9 @@ class GuildMemberList:
|
|||
|
||||
async def _sort_groups(self):
|
||||
for member_ids in self.list.data.values():
|
||||
|
||||
# this should update the list in-place
|
||||
member_ids.sort(
|
||||
key=self.display_name)
|
||||
key=self._display_name)
|
||||
|
||||
async def __init_member_list(self):
|
||||
"""Generate the main member list with groups."""
|
||||
|
|
@ -492,7 +496,7 @@ class GuildMemberList:
|
|||
self.list.presences = {int(p['user']['id']): p
|
||||
for p in presences}
|
||||
|
||||
await self.set_groups()
|
||||
await self._set_groups()
|
||||
|
||||
log.debug('init: {} members, {} groups',
|
||||
len(member_ids),
|
||||
|
|
@ -514,7 +518,7 @@ class GuildMemberList:
|
|||
finally:
|
||||
self._list_lock.release()
|
||||
|
||||
def get_member_as_item(self, member_id: int) -> dict:
|
||||
def _get_member_as_item(self, member_id: int) -> dict:
|
||||
"""Get an item representing a member."""
|
||||
member = self.list.members[member_id]
|
||||
presence = self.list.presences[member_id]
|
||||
|
|
@ -550,17 +554,17 @@ class GuildMemberList:
|
|||
|
||||
for member_id in member_ids:
|
||||
res.append({
|
||||
'member': self.get_member_as_item(member_id)
|
||||
'member': self._get_member_as_item(member_id)
|
||||
})
|
||||
|
||||
return res
|
||||
|
||||
async def sub(self, _session_id: str):
|
||||
"""Subscribe a shard to the member list."""
|
||||
await self._init_check()
|
||||
|
||||
def unsub(self, session_id: str):
|
||||
"""Unsubscribe a shard from the member list"""
|
||||
"""Unsubscribe a shard from the member list
|
||||
|
||||
Subscription for the member list is handled via the
|
||||
:meth:`GuildMemberList.shard_query` method.
|
||||
"""
|
||||
try:
|
||||
self.state.pop(session_id)
|
||||
except KeyError:
|
||||
|
|
@ -574,8 +578,11 @@ class GuildMemberList:
|
|||
if not self.state:
|
||||
self._set_empty_list()
|
||||
|
||||
def get_state(self, session_id: str):
|
||||
"""Get the state for a session id."""
|
||||
def _get_state(self, session_id: str):
|
||||
"""Get the state for a session id.
|
||||
|
||||
Wrapper for :meth:`StateManager.fetch_raw`
|
||||
"""
|
||||
try:
|
||||
state = self.state_man.fetch_raw(session_id)
|
||||
return state
|
||||
|
|
@ -605,7 +612,7 @@ class GuildMemberList:
|
|||
]
|
||||
}
|
||||
|
||||
states = map(self.get_state, session_ids)
|
||||
states = map(self._get_state, session_ids)
|
||||
states = filter(lambda state: state is not None, states)
|
||||
|
||||
dispatched = []
|
||||
|
|
@ -618,7 +625,7 @@ class GuildMemberList:
|
|||
|
||||
return dispatched
|
||||
|
||||
async def resync(self, session_ids: int, item_index: int) -> List[str]:
|
||||
async def _resync(self, session_ids: int, item_index: int) -> List[str]:
|
||||
"""Send a SYNC event to all states that are subscribed to an item.
|
||||
|
||||
Returns
|
||||
|
|
@ -653,13 +660,13 @@ class GuildMemberList:
|
|||
|
||||
return result
|
||||
|
||||
async def resync_by_item(self, item_index: int):
|
||||
async def _resync_by_item(self, item_index: int):
|
||||
"""Resync but only giving the item index."""
|
||||
if item_index is None:
|
||||
return []
|
||||
|
||||
return await self.resync(
|
||||
self.get_subs(item_index),
|
||||
return await self._resync(
|
||||
self._get_subs(item_index),
|
||||
item_index
|
||||
)
|
||||
|
||||
|
|
@ -722,8 +729,9 @@ class GuildMemberList:
|
|||
# send SYNCs to the state that requested
|
||||
await self._dispatch_sess([session_id], ops)
|
||||
|
||||
def get_item_index(self, user_id: Union[str, int]) -> int:
|
||||
def _get_item_index(self, user_id: Union[str, int]) -> int:
|
||||
"""Get the item index a user is on."""
|
||||
# NOTE: this is inefficient
|
||||
user_id = int(user_id)
|
||||
index = 1
|
||||
|
||||
|
|
@ -741,7 +749,7 @@ class GuildMemberList:
|
|||
|
||||
return None
|
||||
|
||||
def get_group_item_index(self, group_id: GroupID) -> int:
|
||||
def _get_group_item_index(self, group_id: GroupID) -> int:
|
||||
"""Get the item index a group is on."""
|
||||
index = 0
|
||||
|
||||
|
|
@ -753,7 +761,7 @@ class GuildMemberList:
|
|||
|
||||
return None
|
||||
|
||||
def state_is_subbed(self, item_index, session_id: str) -> bool:
|
||||
def _is_subbed(self, item_index, session_id: str) -> bool:
|
||||
"""Return if a state's ranges include the given
|
||||
item index."""
|
||||
|
||||
|
|
@ -765,15 +773,21 @@ class GuildMemberList:
|
|||
|
||||
return False
|
||||
|
||||
def get_subs(self, item_index: int) -> filter:
|
||||
def _get_subs(self, item_index: int) -> filter:
|
||||
"""Get the list of subscribed states to a given item."""
|
||||
return filter(
|
||||
lambda sess_id: self.state_is_subbed(item_index, sess_id),
|
||||
lambda sess_id: self._is_subbed(item_index, sess_id),
|
||||
self.state.keys()
|
||||
)
|
||||
|
||||
async def _pres_update_simple(self, user_id: int):
|
||||
item_index = self.get_item_index(user_id)
|
||||
"""Handler for simple presence updates.
|
||||
|
||||
Simple presence updates are just a single UPDATE operator for
|
||||
the client. usually called when a user still maintains their role
|
||||
list but changes to from online to idle/dnd and vice-versa.
|
||||
"""
|
||||
item_index = self._get_item_index(user_id)
|
||||
|
||||
if item_index is None:
|
||||
log.warning('lazy guild got invalid pres update uid={}',
|
||||
|
|
@ -781,7 +795,7 @@ class GuildMemberList:
|
|||
return []
|
||||
|
||||
item = self.items[item_index]
|
||||
session_ids = self.get_subs(item_index)
|
||||
session_ids = self._get_subs(item_index)
|
||||
|
||||
# simple update means we just give an UPDATE
|
||||
# operation
|
||||
|
|
@ -818,8 +832,8 @@ class GuildMemberList:
|
|||
|
||||
ops = []
|
||||
|
||||
old_user_index = self.get_item_index(user_id)
|
||||
old_group_index = self.get_group_item_index(old_group)
|
||||
old_user_index = self._get_item_index(user_id)
|
||||
old_group_index = self._get_group_item_index(old_group)
|
||||
|
||||
ops.append(Operation('DELETE', {
|
||||
'index': old_user_index
|
||||
|
|
@ -831,7 +845,7 @@ class GuildMemberList:
|
|||
|
||||
await self._sort_groups()
|
||||
|
||||
new_user_index = self.get_item_index(user_id)
|
||||
new_user_index = self._get_item_index(user_id)
|
||||
|
||||
ops.append(Operation('INSERT', {
|
||||
'index': new_user_index,
|
||||
|
|
@ -845,7 +859,7 @@ class GuildMemberList:
|
|||
# the first member in the group.
|
||||
if self.list.is_birth(new_group) and new_group != 'offline':
|
||||
ops.append(Operation('INSERT', {
|
||||
'index': self.get_group_item_index(new_group),
|
||||
'index': self._get_group_item_index(new_group),
|
||||
'item': {
|
||||
'group': str(new_group), 'count': 1
|
||||
}
|
||||
|
|
@ -858,8 +872,8 @@ class GuildMemberList:
|
|||
'index': old_group_index,
|
||||
}))
|
||||
|
||||
session_ids_old = list(self.get_subs(old_user_index))
|
||||
session_ids_new = list(self.get_subs(new_user_index))
|
||||
session_ids_old = list(self._get_subs(old_user_index))
|
||||
session_ids_new = list(self._get_subs(new_user_index))
|
||||
# session_ids = set(session_ids_old + session_ids_new)
|
||||
|
||||
# NOTE: this section is what a realistic implementation
|
||||
|
|
@ -876,8 +890,8 @@ class GuildMemberList:
|
|||
# )
|
||||
|
||||
# merge both results together
|
||||
return (await self.resync(session_ids_old, old_user_index) +
|
||||
await self.resync(session_ids_new, new_user_index))
|
||||
return (await self._resync(session_ids_old, old_user_index) +
|
||||
await self._resync(session_ids_new, new_user_index))
|
||||
|
||||
async def new_member(self, user_id: int):
|
||||
"""Insert a new member."""
|
||||
|
|
@ -906,7 +920,7 @@ class GuildMemberList:
|
|||
self.list.members[user_id] = member
|
||||
|
||||
# find a group for the newcomer
|
||||
group_id = await self.get_group_for_member(
|
||||
group_id = await self._get_group_for_member(
|
||||
user_id, member['roles'], pres['status'])
|
||||
|
||||
if group_id is None:
|
||||
|
|
@ -917,13 +931,13 @@ class GuildMemberList:
|
|||
self.list.data[group_id].append(user_id)
|
||||
await self._sort_groups()
|
||||
|
||||
user_index = self.get_item_index(user_id)
|
||||
user_index = self._get_item_index(user_id)
|
||||
|
||||
if not user_index:
|
||||
log.warning('lazy: new uid {} was not assigned idx',
|
||||
user_id)
|
||||
|
||||
return await self.resync_by_item(user_index)
|
||||
return await self._resync_by_item(user_index)
|
||||
|
||||
async def remove_member(self, user_id: int):
|
||||
"""Remove a member from the list."""
|
||||
|
|
@ -933,17 +947,13 @@ class GuildMemberList:
|
|||
return
|
||||
|
||||
# we need the old index to resync later on
|
||||
old_idx = self.get_item_index(user_id)
|
||||
|
||||
def is_valid_state(session_id):
|
||||
state = self.get_state(session_id)
|
||||
return state.user_id != user_id
|
||||
old_idx = self._get_item_index(user_id)
|
||||
|
||||
# for now, remove any of the users' subscribed states
|
||||
state_keys = self.state.keys()
|
||||
|
||||
for session_id in state_keys:
|
||||
state = self.get_state(session_id)
|
||||
state = self._get_state(session_id)
|
||||
|
||||
# if unknown state, remove from the subscriber list
|
||||
if state is None:
|
||||
|
|
@ -978,7 +988,7 @@ class GuildMemberList:
|
|||
log.warning('lazy: unknown member uid {}', user_id)
|
||||
return
|
||||
|
||||
group_id = await self.get_group_for_member(
|
||||
group_id = await self._get_group_for_member(
|
||||
user_id, member['roles'], pres['status'])
|
||||
|
||||
if not group_id:
|
||||
|
|
@ -992,7 +1002,7 @@ class GuildMemberList:
|
|||
return
|
||||
|
||||
# tell everyone about the removal.
|
||||
await self.resync_by_item(old_idx)
|
||||
await self._resync_by_item(old_idx)
|
||||
|
||||
async def update_user(self, user_id: int):
|
||||
"""Called for user updates such as avatar or username."""
|
||||
|
|
@ -1009,8 +1019,8 @@ class GuildMemberList:
|
|||
await self.storage.get_user(user_id)
|
||||
|
||||
# redispatch
|
||||
user_idx = self.get_item_index(user_id)
|
||||
return await self.resync_by_item(user_idx)
|
||||
user_idx = self._get_item_index(user_id)
|
||||
return await self._resync_by_item(user_idx)
|
||||
|
||||
async def pres_update(self, user_id: int,
|
||||
partial_presence: Presence):
|
||||
|
|
@ -1075,7 +1085,7 @@ class GuildMemberList:
|
|||
|
||||
# calculate a possible new group
|
||||
# TODO: handle when new_group is None (member loses perms)
|
||||
new_group = await self.get_group_for_member(
|
||||
new_group = await self._get_group_for_member(
|
||||
user_id, roles, status)
|
||||
|
||||
log.debug('pres update: gid={} cid={} old_g={} new_g={}',
|
||||
|
|
@ -1171,13 +1181,13 @@ class GuildMemberList:
|
|||
"""
|
||||
role_id = int(role['id'])
|
||||
|
||||
old_index = self.get_group_item_index(role_id)
|
||||
old_index = self._get_group_item_index(role_id)
|
||||
|
||||
if not old_index:
|
||||
log.warning('lazy role_pos_update: unknown group {}', role_id)
|
||||
return
|
||||
|
||||
old_sessions = list(self.get_subs(old_index))
|
||||
old_sessions = list(self._get_subs(old_index))
|
||||
|
||||
groups_idx = self._get_role_as_group_idx(role_id)
|
||||
if groups_idx is None:
|
||||
|
|
@ -1202,10 +1212,10 @@ class GuildMemberList:
|
|||
[g.gid for g in new_groups])
|
||||
|
||||
self.list.groups = new_groups
|
||||
new_index = self.get_group_item_index(role_id)
|
||||
new_index = self._get_group_item_index(role_id)
|
||||
|
||||
return (await self.resync(old_sessions, old_index) +
|
||||
await self.resync_by_item(new_index))
|
||||
return (await self._resync(old_sessions, old_index) +
|
||||
await self._resync_by_item(new_index))
|
||||
|
||||
async def role_update(self, role: dict):
|
||||
"""Update a role.
|
||||
|
|
@ -1267,7 +1277,7 @@ class GuildMemberList:
|
|||
# states we'll resend the list info to.
|
||||
|
||||
# find the item id for the group info
|
||||
role_item_index = self.get_group_item_index(role_id)
|
||||
role_item_index = self._get_group_item_index(role_id)
|
||||
|
||||
# we only resync when we actually have an item to resync
|
||||
# we don't have items to resync when we:
|
||||
|
|
@ -1279,7 +1289,7 @@ class GuildMemberList:
|
|||
|
||||
# using a filter object would cause problems
|
||||
# as we only resync AFTER we delete the group
|
||||
sess_ids_resync = (list(self.get_subs(role_item_index))
|
||||
sess_ids_resync = (list(self._get_subs(role_item_index))
|
||||
if role_item_index is not None
|
||||
else [])
|
||||
|
||||
|
|
@ -1328,7 +1338,7 @@ class GuildMemberList:
|
|||
log.debug('there are {} session ids to resync (for item {})',
|
||||
len(sess_ids_resync), role_item_index)
|
||||
|
||||
return await self.resync(sess_ids_resync, role_item_index)
|
||||
return await self._resync(sess_ids_resync, role_item_index)
|
||||
|
||||
async def chan_update(self):
|
||||
"""Called then a channel's data has been updated."""
|
||||
|
|
@ -1359,7 +1369,7 @@ class GuildMemberList:
|
|||
self.guild_id = None
|
||||
self.channel_id = None
|
||||
self.main = None
|
||||
self.list = MemberList
|
||||
self._set_empty_list()
|
||||
self.state = {}
|
||||
|
||||
|
||||
|
|
@ -1412,6 +1422,7 @@ class LazyGuildDispatcher(Dispatcher):
|
|||
))
|
||||
|
||||
async def unsub(self, chan_id, session_id):
|
||||
"""Unsubscribe a session from the list."""
|
||||
gml = await self.get_gml(chan_id)
|
||||
gml.unsub(session_id)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue