pubsub.lazy_guilds: major refactor

This makes the whole process of generating a member list
easier to understand and modify (from my point of view).

The actual event dispatching functionality is not on this
commit.

 - permissions: add optional storage kwarg
This commit is contained in:
Luna Mendes 2018-11-05 22:04:48 -03:00
parent 7c274f0f70
commit 04d89a2214
2 changed files with 237 additions and 149 deletions

View File

@ -65,7 +65,7 @@ class Permissions(ctypes.Union):
ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111) ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111)
async def base_permissions(member_id, guild_id) -> Permissions: async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
"""Compute the base permissions for a given user. """Compute the base permissions for a given user.
Base permissions are Base permissions are
@ -75,7 +75,11 @@ async def base_permissions(member_id, guild_id) -> Permissions:
This will give ALL_PERMISSIONS if base permissions This will give ALL_PERMISSIONS if base permissions
has the Administrator bit set. has the Administrator bit set.
""" """
owner_id = await app.db.fetchval("""
if not storage:
storage = app.storage
owner_id = await storage.db.fetchval("""
SELECT owner_id SELECT owner_id
FROM guilds FROM guilds
WHERE id = $1 WHERE id = $1
@ -85,7 +89,7 @@ async def base_permissions(member_id, guild_id) -> Permissions:
return ALL_PERMISSIONS return ALL_PERMISSIONS
# get permissions for @everyone # get permissions for @everyone
everyone_perms = await app.db.fetchval(""" everyone_perms = await storage.db.fetchval("""
SELECT permissions SELECT permissions
FROM roles FROM roles
WHERE guild_id = $1 WHERE guild_id = $1
@ -93,7 +97,7 @@ async def base_permissions(member_id, guild_id) -> Permissions:
permissions = Permissions(everyone_perms) permissions = Permissions(everyone_perms)
role_ids = await app.db.fetch(""" role_ids = await storage.db.fetch("""
SELECT role_id SELECT role_id
FROM member_roles FROM member_roles
WHERE guild_id = $1 AND user_id = $2 WHERE guild_id = $1 AND user_id = $2
@ -102,7 +106,7 @@ async def base_permissions(member_id, guild_id) -> Permissions:
role_perms = [] role_perms = []
for row in role_ids: for row in role_ids:
rperm = await app.db.fetchval(""" rperm = await storage.db.fetchval("""
SELECT permissions SELECT permissions
FROM roles FROM roles
WHERE id = $1 WHERE id = $1
@ -119,7 +123,7 @@ async def base_permissions(member_id, guild_id) -> Permissions:
return permissions return permissions
def _mix(perms: Permissions, overwrite: dict) -> Permissions: def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions:
# we make a copy of the binary representation # we make a copy of the binary representation
# so we don't modify the old perms in-place # so we don't modify the old perms in-place
# which could be an unwanted side-effect # which could be an unwanted side-effect
@ -134,20 +138,22 @@ def _mix(perms: Permissions, overwrite: dict) -> Permissions:
return Permissions(result) return Permissions(result)
def _overwrite_mix(perms: Permissions, overwrites: dict, def overwrite_find_mix(perms: Permissions, overwrites: dict,
target_id: int) -> Permissions: target_id: int) -> Permissions:
overwrite = overwrites.get(target_id) overwrite = overwrites.get(target_id)
if overwrite: if overwrite:
# only mix if overwrite found # only mix if overwrite found
return _mix(perms, overwrite) return overwrite_mix(perms, overwrite)
return perms return perms
async def compute_overwrites(base_perms, user_id, channel_id: int, async def compute_overwrites(base_perms, user_id, channel_id: int,
guild_id: int = None): guild_id: int = None, storage=None):
"""Compute the permissions in the context of a channel.""" """Compute the permissions in the context of a channel."""
if not storage:
storage = app.storage
if base_perms.bits.administrator: if base_perms.bits.administrator:
return ALL_PERMISSIONS return ALL_PERMISSIONS
@ -155,21 +161,21 @@ async def compute_overwrites(base_perms, user_id, channel_id: int,
perms = base_perms perms = base_perms
# list of overwrites # list of overwrites
overwrites = await app.storage.chan_overwrites(channel_id) overwrites = await storage.chan_overwrites(channel_id)
if not guild_id: if not guild_id:
guild_id = await app.storage.guild_from_channel(channel_id) guild_id = await storage.guild_from_channel(channel_id)
# make it a map for better usage # make it a map for better usage
overwrites = {o['id']: o for o in overwrites} overwrites = {o['id']: o for o in overwrites}
perms = _overwrite_mix(perms, overwrites, guild_id) perms = overwrite_find_mix(perms, overwrites, guild_id)
# apply role specific overwrites # apply role specific overwrites
allow, deny = 0, 0 allow, deny = 0, 0
# fetch roles from user and convert to int # fetch roles from user and convert to int
role_ids = await app.storage.get_member_role_ids(guild_id, user_id) role_ids = await storage.get_member_role_ids(guild_id, user_id)
role_ids = map(int, role_ids) role_ids = map(int, role_ids)
# make the allow and deny binaries # make the allow and deny binaries
@ -180,26 +186,29 @@ async def compute_overwrites(base_perms, user_id, channel_id: int,
deny |= overwrite['deny'] deny |= overwrite['deny']
# final step for roles: mix # final step for roles: mix
perms = _mix(perms, { perms = overwrite_mix(perms, {
'allow': allow, 'allow': allow,
'deny': deny 'deny': deny
}) })
# apply member specific overwrites # apply member specific overwrites
perms = _overwrite_mix(perms, overwrites, user_id) perms = overwrite_find_mix(perms, overwrites, user_id)
return perms return perms
async def get_permissions(member_id, channel_id): async def get_permissions(member_id, channel_id, *, storage=None):
"""Get all the permissions for a user in a channel.""" """Get all the permissions for a user in a channel."""
guild_id = await app.storage.guild_from_channel(channel_id) if not storage:
storage = app.storage
guild_id = await storage.guild_from_channel(channel_id)
# for non guild channels # for non guild channels
if not guild_id: if not guild_id:
return ALL_PERMISSIONS return ALL_PERMISSIONS
base_perms = await base_permissions(member_id, guild_id) base_perms = await base_permissions(member_id, guild_id, storage)
return await compute_overwrites(base_perms, member_id, return await compute_overwrites(base_perms, member_id,
channel_id, guild_id) channel_id, guild_id, storage)

View File

@ -2,15 +2,57 @@
Main code for Lazy Guild implementation in litecord. Main code for Lazy Guild implementation in litecord.
""" """
import pprint import pprint
from dataclasses import dataclass, asdict
from collections import defaultdict from collections import defaultdict
from typing import Any, List, Dict from typing import Any, List, Dict, Union
from logbook import Logger from logbook import Logger
from .dispatcher import Dispatcher from litecord.pubsub.dispatcher import Dispatcher
from litecord.permissions import (
Permissions, overwrite_find_mix, get_permissions
)
log = Logger(__name__) log = Logger(__name__)
GroupID = Union[int, str]
Presence = Dict[str, Any]
@dataclass
class GroupInfo:
gid: GroupID
name: str
position: int
permissions: Permissions
@dataclass
class MemberList:
groups: List[GroupInfo] = None
group_info: Dict[GroupID, GroupInfo] = None
data: Dict[GroupID, Presence] = None
overwrites: Dict[int, Dict[str, Any]] = None
def __bool__(self):
"""Return if the current member list is fully initialized."""
list_dict = asdict(self)
return all(v is not None for v in list_dict.values())
def __iter__(self):
"""Iterate over all groups in the correct order.
Yields a tuple containing :class:`GroupInfo` and
the List[Presence] for the group.
"""
for group in self.groups:
yield group, self.data[group.gid]
def _to_simple_group(presence: dict) -> str:
"""Return a simple group (not a role), given a presence."""
return 'offline' if presence['status'] == 'offline' else 'online'
class GuildMemberList: class GuildMemberList:
"""This class stores the current member list information """This class stores the current member list information
@ -41,7 +83,6 @@ class GuildMemberList:
""" """
def __init__(self, guild_id: int, def __init__(self, guild_id: int,
channel_id: int, main_lg): channel_id: int, main_lg):
self.main_lg = main_lg
self.guild_id = guild_id self.guild_id = guild_id
self.channel_id = channel_id self.channel_id = channel_id
@ -52,167 +93,205 @@ class GuildMemberList:
self.presence = main.app.presence self.presence = main.app.presence
self.state_man = main.app.state_manager self.state_man = main.app.state_manager
self.member_list = None self.list = MemberList(None, None, None, None)
self.items = None
#: holds the state of subscribed shards #: holds the state of subscribed shards
# to this channels' member list # to this channels' member list
self.state = set() self.state = set()
def _set_empty_list(self):
self.list = MemberList(None, None, None, None)
async def _init_check(self): async def _init_check(self):
"""Check if the member list is initialized before """Check if the member list is initialized before
messing with it.""" messing with it."""
if self.member_list is None: if not self.list:
await self._init_member_list() await self._init_member_list()
async def get_roles(self) -> List[Dict[str, Any]]: async def _fetch_overwrites(self):
overwrites = await self.storage.chan_overwrites(self.channel_id)
overwrites = {int(ov['id']): ov for ov in overwrites}
self.list.overwrites = overwrites
def _calc_member_group(self, roles: List[int], status: str):
"""Calculate the best fitting group for a member,
given their roles and their current status."""
try:
# the first group in the list
# that the member is entitled to is
# the selected group for the member.
group_id = next(g.gid for g in self.list.groups
if g.gid in roles)
except StopIteration:
# no group was found, so we fallback
# to simple group"
group_id = _to_simple_group({'status': status})
return group_id
async def get_roles(self) -> List[GroupInfo]:
"""Get role information, but only: """Get role information, but only:
- the ID - the ID
- the name - the name
- the position - the position
- the permissions
of all HOISTED roles.""" of all HOISTED roles AND roles that
# TODO: write own query for this have permissions to read the channel
# TODO: calculate channel overrides being referred to this :class:`GuildMemberList`
roles = await self.storage.get_role_data(self.guild_id) instance.
return [{ The list is sorted by position.
'id': role['id'], """
'name': role['name'], roledata = await self.storage.db.fetch("""
'position': role['position'] SELECT id, name, hoist, position, permissions
} for role in roles if role['hoist']] FROM roles
WHERE guild_id = $1
""", self.guild_id)
hoisted = [
GroupInfo(row['id'], row['name'],
row['position'], row['permissions'])
for row in roledata if row['hoist']
]
# sort role list by position
hoisted = sorted(hoisted, key=lambda group: group.position)
# we need to store them for later on
# for members
await self._fetch_overwrites()
def _can_read_chan(group: GroupInfo):
# get the base role perms
role_perms = group.permissions
# then the final perms for that role if
# any overwrite exists in the channel
final_perms = overwrite_find_mix(
role_perms, self.list.overwrites, group.gid)
# update the group's permissions
# with the mixed ones
group.permissions = final_perms
# if the role can read messages, then its
# part of the group.
return final_perms.bits.read_messages
return list(filter(_can_read_chan, hoisted))
async def set_groups(self):
"""Get the groups for the member list."""
role_groups = await self.get_roles()
role_ids = [g.gid for g in role_groups]
self.list.groups = role_ids + ['online', 'offline']
# inject default groups 'online' and 'offline'
self.list.groups = role_ids + [
GroupInfo('online', 'online', -1, -1),
GroupInfo('offline', 'offline', -1, -1)
]
self.list.group_info = {g.gid: g for g in role_groups}
async def _pass_1(self, guild_presences: List[Presence]):
"""First pass on generating the member list.
This assigns all presences a single group.
"""
for presence in guild_presences:
member_id = int(presence['user']['id'])
# list of roles for the member
member_roles = list(map(int, presence['roles']))
# get the member's permissions relative to the channel
# (accounting for channel overwrites)
member_perms = await get_permissions(
member_id, self.channel_id, storage=self.storage)
if not member_perms.bits.read_messages:
continue
# if the member is offline, we
# default give them the offline group.
status = presence['status']
group_id = ('offline' if status == 'offline'
else self._calc_member_group(member_roles, status))
self.list.data[group_id].append(presence)
async def _sort_groups(self):
members = await self.storage.get_member_data(self.guild_id)
# make a dictionary of member ids to nicknames
# so we don't need to keep querying the db on
# every loop iteration
member_nicks = {m['user']['id']: m.get('nick')
for m in members}
for group_members in self.list.data.values():
def display_name(presence: Presence) -> str:
uid = presence['user']['id']
uname = presence['user']['username']
nick = member_nicks.get(uid)
return nick or uname
# this should update the list in-place
group_members.sort(key=display_name)
async def _init_member_list(self): async def _init_member_list(self):
"""Fill in :attr:`GuildMemberList.member_list` """Generate the main member list with groups."""
with information about the guilds' members."""
member_ids = await self.storage.get_member_ids(self.guild_id) member_ids = await self.storage.get_member_ids(self.guild_id)
guild_presences = await self.presence.guild_presences( guild_presences = await self.presence.guild_presences(
member_ids, self.guild_id) member_ids, self.guild_id)
guild_roles = await self.get_roles() await self.set_groups()
# sort by position
guild_roles.sort(key=lambda role: role['position'])
roleids = [r['id'] for r in guild_roles]
# groups are:
# - roles that are hoisted
# - "online" and "offline", with "online"
# being for people without any roles.
groups = roleids + ['online', 'offline']
log.debug('{} presences, {} groups', log.debug('{} presences, {} groups',
len(guild_presences), len(groups)) len(guild_presences),
len(self.list.groups))
group_data = {group: [] for group in groups} self.list.data = {group.gid: [] for group in self.list.groups}
print('group data', group_data) # first pass: set which presences
# go to which groups
await self._pass_1(guild_presences)
def _try_hier(role_id: str, roleids: list): # second pass: sort each group's members
"""Try to fetch a role's position in the hierarchy""" # by the display name
try: await self._sort_groups()
return roleids.index(role_id)
except ValueError:
# the given role isn't on a group
# so it doesn't count for our purposes.
return 0
for presence in guild_presences: @property
# simple group (online or offline) def items(self) -> list:
# we'll decide on the best group for the presence later on """Main items list."""
simple_group = ('offline'
if presence['status'] == 'offline'
else 'online')
# get the best possible role # TODO: maybe make this stored in the list
roles = sorted( # so we don't need to keep regenning?
presence['roles'],
key=lambda role_id: _try_hier(role_id, roleids)
)
try: if not self.list:
best_role_id = roles[0]
except IndexError:
# no hoisted roles exist in the guild, assign
# the @everyone role
best_role_id = str(self.guild_id)
print('best role', best_role_id, str(self.guild_id))
print('simple group assign', simple_group)
# if the best role is literally the @everyone role,
# this user has no hoisted roles
if best_role_id == str(self.guild_id):
# this user has no roles, put it on online/offline
group_data[simple_group].append(presence)
continue
# this user has a best_role that isn't the
# @everyone role, so we'll put them in the respective group
try:
group_data[best_role_id].append(presence)
except KeyError:
group_data[simple_group].append(presence)
# go through each group and sort the resulting members by display name
members = await self.storage.get_member_data(self.guild_id)
member_nicks = {member['user']['id']: member.get('nick')
for member in members}
# now we'll sort each group by their display name
# (can be their current nickname OR their username
# if no nickname is set)
print('pre-sorted group data')
pprint.pprint(group_data)
for _, group_list in group_data.items():
def display_name(presence: dict) -> str:
uid = presence['user']['id']
uname = presence['user']['username']
nick = member_nicks[uid]
return nick or uname
group_list.sort(key=display_name)
pprint.pprint(group_data)
self.member_list = {
'groups': groups,
'data': group_data
}
def get_items(self) -> list:
"""Generate the main items list,"""
if self.member_list is None:
return [] return []
if self.items:
return self.items
groups = self.member_list['groups']
res = [] res = []
for group in groups:
members = self.member_list['data'][group]
# NOTE: maybe use map()?
for group, presences in self.list:
res.append({ res.append({
'group': { 'group': {
'id': group, 'id': group.gid,
'count': len(members), 'count': len(presences),
} }
}) })
for member in members: for presence in presences:
res.append({ res.append({
'member': member 'member': presence
}) })
self.items = res
return res return res
async def sub(self, session_id: str): async def sub(self, session_id: str):
@ -230,7 +309,7 @@ class GuildMemberList:
# uninitialized) for a future subscriber. # uninitialized) for a future subscriber.
if not self.state: if not self.state:
self.member_list = None self._set_empty_list()
async def shard_query(self, session_id: str, ranges: list): async def shard_query(self, session_id: str, ranges: list):
"""Send a GUILD_MEMBER_LIST_UPDATE event """Send a GUILD_MEMBER_LIST_UPDATE event
@ -259,7 +338,7 @@ class GuildMemberList:
# subscribe the state to this list # subscribe the state to this list
await self.sub(session_id) await self.sub(session_id)
# TODO: subscribe shard to 'everyone' # TODO: subscribe shard to the 'everyone' member list
# and forward the query to that list # and forward the query to that list
reply = { reply = {
@ -273,9 +352,9 @@ class GuildMemberList:
'groups': [ 'groups': [
{ {
'count': len(self.member_list['data'][group]), 'count': len(presences),
'id': group 'id': group.gid
} for group in self.member_list['groups'] } for group, presences in self.list
], ],
'ops': [], 'ops': [],
@ -288,12 +367,12 @@ class GuildMemberList:
if itemcount < 0: if itemcount < 0:
continue continue
items = self.get_items() # TODO: subscribe user to the slice
reply['ops'].append({ reply['ops'].append({
'op': 'SYNC', 'op': 'SYNC',
'range': [start, end], 'range': [start, end],
'items': items[start:end], 'items': self.items[start:end],
}) })
# the first GUILD_MEMBER_LIST_UPDATE for a shard # the first GUILD_MEMBER_LIST_UPDATE for a shard