mirror of https://gitlab.com/litecord/litecord.git
pubsub.lazy_guilds: major refactor
This makes the whole process of generating a member list easier to understand and modify (from my point of view). The actual event dispatching functionality is not on this commit. - permissions: add optional storage kwarg
This commit is contained in:
parent
7c274f0f70
commit
04d89a2214
|
|
@ -65,7 +65,7 @@ class Permissions(ctypes.Union):
|
|||
ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111)
|
||||
|
||||
|
||||
async def base_permissions(member_id, guild_id) -> Permissions:
|
||||
async def base_permissions(member_id, guild_id, storage=None) -> Permissions:
|
||||
"""Compute the base permissions for a given user.
|
||||
|
||||
Base permissions are
|
||||
|
|
@ -75,7 +75,11 @@ async def base_permissions(member_id, guild_id) -> Permissions:
|
|||
This will give ALL_PERMISSIONS if base permissions
|
||||
has the Administrator bit set.
|
||||
"""
|
||||
owner_id = await app.db.fetchval("""
|
||||
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
||||
owner_id = await storage.db.fetchval("""
|
||||
SELECT owner_id
|
||||
FROM guilds
|
||||
WHERE id = $1
|
||||
|
|
@ -85,7 +89,7 @@ async def base_permissions(member_id, guild_id) -> Permissions:
|
|||
return ALL_PERMISSIONS
|
||||
|
||||
# get permissions for @everyone
|
||||
everyone_perms = await app.db.fetchval("""
|
||||
everyone_perms = await storage.db.fetchval("""
|
||||
SELECT permissions
|
||||
FROM roles
|
||||
WHERE guild_id = $1
|
||||
|
|
@ -93,7 +97,7 @@ async def base_permissions(member_id, guild_id) -> Permissions:
|
|||
|
||||
permissions = Permissions(everyone_perms)
|
||||
|
||||
role_ids = await app.db.fetch("""
|
||||
role_ids = await storage.db.fetch("""
|
||||
SELECT role_id
|
||||
FROM member_roles
|
||||
WHERE guild_id = $1 AND user_id = $2
|
||||
|
|
@ -102,7 +106,7 @@ async def base_permissions(member_id, guild_id) -> Permissions:
|
|||
role_perms = []
|
||||
|
||||
for row in role_ids:
|
||||
rperm = await app.db.fetchval("""
|
||||
rperm = await storage.db.fetchval("""
|
||||
SELECT permissions
|
||||
FROM roles
|
||||
WHERE id = $1
|
||||
|
|
@ -119,7 +123,7 @@ async def base_permissions(member_id, guild_id) -> Permissions:
|
|||
return permissions
|
||||
|
||||
|
||||
def _mix(perms: Permissions, overwrite: dict) -> Permissions:
|
||||
def overwrite_mix(perms: Permissions, overwrite: dict) -> Permissions:
|
||||
# we make a copy of the binary representation
|
||||
# so we don't modify the old perms in-place
|
||||
# which could be an unwanted side-effect
|
||||
|
|
@ -134,20 +138,22 @@ def _mix(perms: Permissions, overwrite: dict) -> Permissions:
|
|||
return Permissions(result)
|
||||
|
||||
|
||||
def _overwrite_mix(perms: Permissions, overwrites: dict,
|
||||
target_id: int) -> Permissions:
|
||||
def overwrite_find_mix(perms: Permissions, overwrites: dict,
|
||||
target_id: int) -> Permissions:
|
||||
overwrite = overwrites.get(target_id)
|
||||
|
||||
if overwrite:
|
||||
# only mix if overwrite found
|
||||
return _mix(perms, overwrite)
|
||||
return overwrite_mix(perms, overwrite)
|
||||
|
||||
return perms
|
||||
|
||||
|
||||
async def compute_overwrites(base_perms, user_id, channel_id: int,
|
||||
guild_id: int = None):
|
||||
guild_id: int = None, storage=None):
|
||||
"""Compute the permissions in the context of a channel."""
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
||||
if base_perms.bits.administrator:
|
||||
return ALL_PERMISSIONS
|
||||
|
|
@ -155,21 +161,21 @@ async def compute_overwrites(base_perms, user_id, channel_id: int,
|
|||
perms = base_perms
|
||||
|
||||
# list of overwrites
|
||||
overwrites = await app.storage.chan_overwrites(channel_id)
|
||||
overwrites = await storage.chan_overwrites(channel_id)
|
||||
|
||||
if not guild_id:
|
||||
guild_id = await app.storage.guild_from_channel(channel_id)
|
||||
guild_id = await storage.guild_from_channel(channel_id)
|
||||
|
||||
# make it a map for better usage
|
||||
overwrites = {o['id']: o for o in overwrites}
|
||||
|
||||
perms = _overwrite_mix(perms, overwrites, guild_id)
|
||||
perms = overwrite_find_mix(perms, overwrites, guild_id)
|
||||
|
||||
# apply role specific overwrites
|
||||
allow, deny = 0, 0
|
||||
|
||||
# fetch roles from user and convert to int
|
||||
role_ids = await app.storage.get_member_role_ids(guild_id, user_id)
|
||||
role_ids = await storage.get_member_role_ids(guild_id, user_id)
|
||||
role_ids = map(int, role_ids)
|
||||
|
||||
# make the allow and deny binaries
|
||||
|
|
@ -180,26 +186,29 @@ async def compute_overwrites(base_perms, user_id, channel_id: int,
|
|||
deny |= overwrite['deny']
|
||||
|
||||
# final step for roles: mix
|
||||
perms = _mix(perms, {
|
||||
perms = overwrite_mix(perms, {
|
||||
'allow': allow,
|
||||
'deny': deny
|
||||
})
|
||||
|
||||
# apply member specific overwrites
|
||||
perms = _overwrite_mix(perms, overwrites, user_id)
|
||||
perms = overwrite_find_mix(perms, overwrites, user_id)
|
||||
|
||||
return perms
|
||||
|
||||
|
||||
async def get_permissions(member_id, channel_id):
|
||||
async def get_permissions(member_id, channel_id, *, storage=None):
|
||||
"""Get all the permissions for a user in a channel."""
|
||||
guild_id = await app.storage.guild_from_channel(channel_id)
|
||||
if not storage:
|
||||
storage = app.storage
|
||||
|
||||
guild_id = await storage.guild_from_channel(channel_id)
|
||||
|
||||
# for non guild channels
|
||||
if not guild_id:
|
||||
return ALL_PERMISSIONS
|
||||
|
||||
base_perms = await base_permissions(member_id, guild_id)
|
||||
base_perms = await base_permissions(member_id, guild_id, storage)
|
||||
|
||||
return await compute_overwrites(base_perms, member_id,
|
||||
channel_id, guild_id)
|
||||
channel_id, guild_id, storage)
|
||||
|
|
|
|||
|
|
@ -2,15 +2,57 @@
|
|||
Main code for Lazy Guild implementation in litecord.
|
||||
"""
|
||||
import pprint
|
||||
from dataclasses import dataclass, asdict
|
||||
from collections import defaultdict
|
||||
from typing import Any, List, Dict
|
||||
from typing import Any, List, Dict, Union
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
from .dispatcher import Dispatcher
|
||||
from litecord.pubsub.dispatcher import Dispatcher
|
||||
from litecord.permissions import (
|
||||
Permissions, overwrite_find_mix, get_permissions
|
||||
)
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
GroupID = Union[int, str]
|
||||
Presence = Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupInfo:
|
||||
gid: GroupID
|
||||
name: str
|
||||
position: int
|
||||
permissions: Permissions
|
||||
|
||||
|
||||
@dataclass
|
||||
class MemberList:
|
||||
groups: List[GroupInfo] = None
|
||||
group_info: Dict[GroupID, GroupInfo] = None
|
||||
data: Dict[GroupID, Presence] = None
|
||||
overwrites: Dict[int, Dict[str, Any]] = None
|
||||
|
||||
def __bool__(self):
|
||||
"""Return if the current member list is fully initialized."""
|
||||
list_dict = asdict(self)
|
||||
return all(v is not None for v in list_dict.values())
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over all groups in the correct order.
|
||||
|
||||
Yields a tuple containing :class:`GroupInfo` and
|
||||
the List[Presence] for the group.
|
||||
"""
|
||||
for group in self.groups:
|
||||
yield group, self.data[group.gid]
|
||||
|
||||
|
||||
def _to_simple_group(presence: dict) -> str:
|
||||
"""Return a simple group (not a role), given a presence."""
|
||||
return 'offline' if presence['status'] == 'offline' else 'online'
|
||||
|
||||
|
||||
class GuildMemberList:
|
||||
"""This class stores the current member list information
|
||||
|
|
@ -41,7 +83,6 @@ class GuildMemberList:
|
|||
"""
|
||||
def __init__(self, guild_id: int,
|
||||
channel_id: int, main_lg):
|
||||
self.main_lg = main_lg
|
||||
self.guild_id = guild_id
|
||||
self.channel_id = channel_id
|
||||
|
||||
|
|
@ -52,167 +93,205 @@ class GuildMemberList:
|
|||
self.presence = main.app.presence
|
||||
self.state_man = main.app.state_manager
|
||||
|
||||
self.member_list = None
|
||||
self.items = None
|
||||
self.list = MemberList(None, None, None, None)
|
||||
|
||||
#: holds the state of subscribed shards
|
||||
# to this channels' member list
|
||||
self.state = set()
|
||||
|
||||
def _set_empty_list(self):
|
||||
self.list = MemberList(None, None, None, None)
|
||||
|
||||
async def _init_check(self):
|
||||
"""Check if the member list is initialized before
|
||||
messing with it."""
|
||||
if self.member_list is None:
|
||||
if not self.list:
|
||||
await self._init_member_list()
|
||||
|
||||
async def get_roles(self) -> List[Dict[str, Any]]:
|
||||
async def _fetch_overwrites(self):
|
||||
overwrites = await self.storage.chan_overwrites(self.channel_id)
|
||||
overwrites = {int(ov['id']): ov for ov in overwrites}
|
||||
self.list.overwrites = overwrites
|
||||
|
||||
def _calc_member_group(self, roles: List[int], status: str):
|
||||
"""Calculate the best fitting group for a member,
|
||||
given their roles and their current status."""
|
||||
try:
|
||||
# the first group in the list
|
||||
# that the member is entitled to is
|
||||
# the selected group for the member.
|
||||
group_id = next(g.gid for g in self.list.groups
|
||||
if g.gid in roles)
|
||||
except StopIteration:
|
||||
# no group was found, so we fallback
|
||||
# to simple group"
|
||||
group_id = _to_simple_group({'status': status})
|
||||
|
||||
return group_id
|
||||
|
||||
async def get_roles(self) -> List[GroupInfo]:
|
||||
"""Get role information, but only:
|
||||
- the ID
|
||||
- the name
|
||||
- the position
|
||||
- the permissions
|
||||
|
||||
of all HOISTED roles."""
|
||||
# TODO: write own query for this
|
||||
# TODO: calculate channel overrides
|
||||
roles = await self.storage.get_role_data(self.guild_id)
|
||||
of all HOISTED roles AND roles that
|
||||
have permissions to read the channel
|
||||
being referred to this :class:`GuildMemberList`
|
||||
instance.
|
||||
|
||||
return [{
|
||||
'id': role['id'],
|
||||
'name': role['name'],
|
||||
'position': role['position']
|
||||
} for role in roles if role['hoist']]
|
||||
The list is sorted by position.
|
||||
"""
|
||||
roledata = await self.storage.db.fetch("""
|
||||
SELECT id, name, hoist, position, permissions
|
||||
FROM roles
|
||||
WHERE guild_id = $1
|
||||
""", self.guild_id)
|
||||
|
||||
hoisted = [
|
||||
GroupInfo(row['id'], row['name'],
|
||||
row['position'], row['permissions'])
|
||||
for row in roledata if row['hoist']
|
||||
]
|
||||
|
||||
# sort role list by position
|
||||
hoisted = sorted(hoisted, key=lambda group: group.position)
|
||||
|
||||
# we need to store them for later on
|
||||
# for members
|
||||
await self._fetch_overwrites()
|
||||
|
||||
def _can_read_chan(group: GroupInfo):
|
||||
# get the base role perms
|
||||
role_perms = group.permissions
|
||||
|
||||
# then the final perms for that role if
|
||||
# any overwrite exists in the channel
|
||||
final_perms = overwrite_find_mix(
|
||||
role_perms, self.list.overwrites, group.gid)
|
||||
|
||||
# update the group's permissions
|
||||
# with the mixed ones
|
||||
group.permissions = final_perms
|
||||
|
||||
# if the role can read messages, then its
|
||||
# part of the group.
|
||||
return final_perms.bits.read_messages
|
||||
|
||||
return list(filter(_can_read_chan, hoisted))
|
||||
|
||||
async def set_groups(self):
|
||||
"""Get the groups for the member list."""
|
||||
role_groups = await self.get_roles()
|
||||
role_ids = [g.gid for g in role_groups]
|
||||
|
||||
self.list.groups = role_ids + ['online', 'offline']
|
||||
|
||||
# inject default groups 'online' and 'offline'
|
||||
self.list.groups = role_ids + [
|
||||
GroupInfo('online', 'online', -1, -1),
|
||||
GroupInfo('offline', 'offline', -1, -1)
|
||||
]
|
||||
self.list.group_info = {g.gid: g for g in role_groups}
|
||||
|
||||
async def _pass_1(self, guild_presences: List[Presence]):
|
||||
"""First pass on generating the member list.
|
||||
|
||||
This assigns all presences a single group.
|
||||
"""
|
||||
for presence in guild_presences:
|
||||
member_id = int(presence['user']['id'])
|
||||
|
||||
# list of roles for the member
|
||||
member_roles = list(map(int, presence['roles']))
|
||||
|
||||
# get the member's permissions relative to the channel
|
||||
# (accounting for channel overwrites)
|
||||
member_perms = await get_permissions(
|
||||
member_id, self.channel_id, storage=self.storage)
|
||||
|
||||
if not member_perms.bits.read_messages:
|
||||
continue
|
||||
|
||||
# if the member is offline, we
|
||||
# default give them the offline group.
|
||||
status = presence['status']
|
||||
group_id = ('offline' if status == 'offline'
|
||||
else self._calc_member_group(member_roles, status))
|
||||
|
||||
self.list.data[group_id].append(presence)
|
||||
|
||||
async def _sort_groups(self):
|
||||
members = await self.storage.get_member_data(self.guild_id)
|
||||
|
||||
# make a dictionary of member ids to nicknames
|
||||
# so we don't need to keep querying the db on
|
||||
# every loop iteration
|
||||
member_nicks = {m['user']['id']: m.get('nick')
|
||||
for m in members}
|
||||
|
||||
for group_members in self.list.data.values():
|
||||
def display_name(presence: Presence) -> str:
|
||||
uid = presence['user']['id']
|
||||
|
||||
uname = presence['user']['username']
|
||||
nick = member_nicks.get(uid)
|
||||
|
||||
return nick or uname
|
||||
|
||||
# this should update the list in-place
|
||||
group_members.sort(key=display_name)
|
||||
|
||||
async def _init_member_list(self):
|
||||
"""Fill in :attr:`GuildMemberList.member_list`
|
||||
with information about the guilds' members."""
|
||||
"""Generate the main member list with groups."""
|
||||
member_ids = await self.storage.get_member_ids(self.guild_id)
|
||||
|
||||
guild_presences = await self.presence.guild_presences(
|
||||
member_ids, self.guild_id)
|
||||
|
||||
guild_roles = await self.get_roles()
|
||||
|
||||
# sort by position
|
||||
guild_roles.sort(key=lambda role: role['position'])
|
||||
roleids = [r['id'] for r in guild_roles]
|
||||
|
||||
# groups are:
|
||||
# - roles that are hoisted
|
||||
# - "online" and "offline", with "online"
|
||||
# being for people without any roles.
|
||||
|
||||
groups = roleids + ['online', 'offline']
|
||||
await self.set_groups()
|
||||
|
||||
log.debug('{} presences, {} groups',
|
||||
len(guild_presences), len(groups))
|
||||
len(guild_presences),
|
||||
len(self.list.groups))
|
||||
|
||||
group_data = {group: [] for group in groups}
|
||||
self.list.data = {group.gid: [] for group in self.list.groups}
|
||||
|
||||
print('group data', group_data)
|
||||
# first pass: set which presences
|
||||
# go to which groups
|
||||
await self._pass_1(guild_presences)
|
||||
|
||||
def _try_hier(role_id: str, roleids: list):
|
||||
"""Try to fetch a role's position in the hierarchy"""
|
||||
try:
|
||||
return roleids.index(role_id)
|
||||
except ValueError:
|
||||
# the given role isn't on a group
|
||||
# so it doesn't count for our purposes.
|
||||
return 0
|
||||
# second pass: sort each group's members
|
||||
# by the display name
|
||||
await self._sort_groups()
|
||||
|
||||
for presence in guild_presences:
|
||||
# simple group (online or offline)
|
||||
# we'll decide on the best group for the presence later on
|
||||
simple_group = ('offline'
|
||||
if presence['status'] == 'offline'
|
||||
else 'online')
|
||||
@property
|
||||
def items(self) -> list:
|
||||
"""Main items list."""
|
||||
|
||||
# get the best possible role
|
||||
roles = sorted(
|
||||
presence['roles'],
|
||||
key=lambda role_id: _try_hier(role_id, roleids)
|
||||
)
|
||||
# TODO: maybe make this stored in the list
|
||||
# so we don't need to keep regenning?
|
||||
|
||||
try:
|
||||
best_role_id = roles[0]
|
||||
except IndexError:
|
||||
# no hoisted roles exist in the guild, assign
|
||||
# the @everyone role
|
||||
best_role_id = str(self.guild_id)
|
||||
|
||||
print('best role', best_role_id, str(self.guild_id))
|
||||
print('simple group assign', simple_group)
|
||||
|
||||
# if the best role is literally the @everyone role,
|
||||
# this user has no hoisted roles
|
||||
if best_role_id == str(self.guild_id):
|
||||
# this user has no roles, put it on online/offline
|
||||
group_data[simple_group].append(presence)
|
||||
continue
|
||||
|
||||
# this user has a best_role that isn't the
|
||||
# @everyone role, so we'll put them in the respective group
|
||||
try:
|
||||
group_data[best_role_id].append(presence)
|
||||
except KeyError:
|
||||
group_data[simple_group].append(presence)
|
||||
|
||||
# go through each group and sort the resulting members by display name
|
||||
|
||||
members = await self.storage.get_member_data(self.guild_id)
|
||||
member_nicks = {member['user']['id']: member.get('nick')
|
||||
for member in members}
|
||||
|
||||
# now we'll sort each group by their display name
|
||||
# (can be their current nickname OR their username
|
||||
# if no nickname is set)
|
||||
print('pre-sorted group data')
|
||||
pprint.pprint(group_data)
|
||||
|
||||
for _, group_list in group_data.items():
|
||||
def display_name(presence: dict) -> str:
|
||||
uid = presence['user']['id']
|
||||
|
||||
uname = presence['user']['username']
|
||||
nick = member_nicks[uid]
|
||||
|
||||
return nick or uname
|
||||
|
||||
group_list.sort(key=display_name)
|
||||
|
||||
pprint.pprint(group_data)
|
||||
|
||||
self.member_list = {
|
||||
'groups': groups,
|
||||
'data': group_data
|
||||
}
|
||||
|
||||
def get_items(self) -> list:
|
||||
"""Generate the main items list,"""
|
||||
if self.member_list is None:
|
||||
if not self.list:
|
||||
return []
|
||||
|
||||
if self.items:
|
||||
return self.items
|
||||
|
||||
groups = self.member_list['groups']
|
||||
|
||||
res = []
|
||||
for group in groups:
|
||||
members = self.member_list['data'][group]
|
||||
|
||||
# NOTE: maybe use map()?
|
||||
for group, presences in self.list:
|
||||
res.append({
|
||||
'group': {
|
||||
'id': group,
|
||||
'count': len(members),
|
||||
'id': group.gid,
|
||||
'count': len(presences),
|
||||
}
|
||||
})
|
||||
|
||||
for member in members:
|
||||
for presence in presences:
|
||||
res.append({
|
||||
'member': member
|
||||
'member': presence
|
||||
})
|
||||
|
||||
self.items = res
|
||||
return res
|
||||
|
||||
async def sub(self, session_id: str):
|
||||
|
|
@ -230,7 +309,7 @@ class GuildMemberList:
|
|||
# uninitialized) for a future subscriber.
|
||||
|
||||
if not self.state:
|
||||
self.member_list = None
|
||||
self._set_empty_list()
|
||||
|
||||
async def shard_query(self, session_id: str, ranges: list):
|
||||
"""Send a GUILD_MEMBER_LIST_UPDATE event
|
||||
|
|
@ -259,7 +338,7 @@ class GuildMemberList:
|
|||
# subscribe the state to this list
|
||||
await self.sub(session_id)
|
||||
|
||||
# TODO: subscribe shard to 'everyone'
|
||||
# TODO: subscribe shard to the 'everyone' member list
|
||||
# and forward the query to that list
|
||||
|
||||
reply = {
|
||||
|
|
@ -273,9 +352,9 @@ class GuildMemberList:
|
|||
|
||||
'groups': [
|
||||
{
|
||||
'count': len(self.member_list['data'][group]),
|
||||
'id': group
|
||||
} for group in self.member_list['groups']
|
||||
'count': len(presences),
|
||||
'id': group.gid
|
||||
} for group, presences in self.list
|
||||
],
|
||||
|
||||
'ops': [],
|
||||
|
|
@ -288,12 +367,12 @@ class GuildMemberList:
|
|||
if itemcount < 0:
|
||||
continue
|
||||
|
||||
items = self.get_items()
|
||||
# TODO: subscribe user to the slice
|
||||
|
||||
reply['ops'].append({
|
||||
'op': 'SYNC',
|
||||
'range': [start, end],
|
||||
'items': items[start:end],
|
||||
'items': self.items[start:end],
|
||||
})
|
||||
|
||||
# the first GUILD_MEMBER_LIST_UPDATE for a shard
|
||||
|
|
|
|||
Loading…
Reference in New Issue