diff --git a/litecord/permissions.py b/litecord/permissions.py index f5fa63c..b63fa98 100644 --- a/litecord/permissions.py +++ b/litecord/permissions.py @@ -65,6 +65,20 @@ class Permissions(ctypes.Union): ALL_PERMISSIONS = Permissions(0b01111111111101111111110111111111) +async def get_role_perms(guild_id, role_id, storage=None) -> Permissions: + """Get the raw :class:`Permissions` object for a role.""" + if not storage: + storage = app.storage + + perms = await storage.db.fetchval(""" + SELECT permissions + FROM roles + WHERE guild_id = $1 AND id = $2 + """, guild_id, role_id) + + return Permissions(perms) + + async def base_permissions(member_id, guild_id, storage=None) -> Permissions: """Compute the base permissions for a given user. @@ -89,13 +103,7 @@ async def base_permissions(member_id, guild_id, storage=None) -> Permissions: return ALL_PERMISSIONS # get permissions for @everyone - everyone_perms = await storage.db.fetchval(""" - SELECT permissions - FROM roles - WHERE guild_id = $1 - """, guild_id) - - permissions = Permissions(everyone_perms) + permissions = await get_role_perms(guild_id, guild_id, storage) role_ids = await storage.db.fetch(""" SELECT role_id @@ -149,6 +157,26 @@ def overwrite_find_mix(perms: Permissions, overwrites: dict, return perms +async def role_permissions(guild_id: int, role_id: int, + channel_id: int, storage=None) -> Permissions: + """Get the permissions for a role, in relation to a channel""" + if not storage: + storage = app.storage + + perms = await get_role_perms(guild_id, role_id, storage) + + overwrite = await storage.db.fetchrow(""" + SELECT allow, deny + FROM channel_overwrites + WHERE channel_id = $1 AND target_type = $2 AND target_role = $3 + """, channel_id, 1, role_id) + + if overwrite: + perms = overwrite_mix(perms, overwrite) + + return perms + + async def compute_overwrites(base_perms, user_id, channel_id: int, guild_id: int = None, storage=None): """Compute the permissions in the context of a channel.""" diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index ba19276..5d865e4 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -10,7 +10,7 @@ from logbook import Logger from litecord.pubsub.dispatcher import Dispatcher from litecord.permissions import ( - Permissions, overwrite_find_mix, get_permissions + Permissions, overwrite_find_mix, get_permissions, role_permissions ) log = Logger(__name__) @@ -90,16 +90,15 @@ class GuildMemberList: # a really long chain of classes to get # to the storage instance... - main = main_lg.main_dispatcher - self.storage = main.app.storage - self.presence = main.app.presence - self.state_man = main.app.state_manager + self.main = main_lg + self.storage = self.main.app.storage + self.presence = self.main.app.presence + self.state_man = self.main.app.state_manager self.list = MemberList(None, None, None, None) - #: holds the state of subscribed shards - # to this channels' member list - self.state = set() + #: {session_id: set[list]} + self.state = defaultdict(set) def _set_empty_list(self): self.list = MemberList(None, None, None, None) @@ -296,14 +295,16 @@ class GuildMemberList: return res - async def sub(self, session_id: str): + async def sub(self, _session_id: str): """Subscribe a shard to the member list.""" await self._init_check() - self.state.add(session_id) async def unsub(self, session_id: str): """Unsubscribe a shard from the member list""" - self.state.discard(session_id) + try: + self.state.pop(session_id) + except KeyError: + pass # once we reach 0 subscribers, # we drop the current member list we have (for memory) @@ -327,6 +328,29 @@ class GuildMemberList: ranges of the list that we want. """ + # a guild list with a channel id of the guild + # represents the 'everyone' global list. + list_id = ('everyone' + if self.channel_id == self.guild_id + else str(self.channel_id)) + + # if everyone can read the channel, + # we direct the request to the 'everyone' gml instance + # instead of the current one. + everyone_perms = await role_permissions( + self.guild_id, + self.guild_id, + self.channel_id, + storage=self.storage + ) + + if everyone_perms.bits.read_messages and list_id != 'everyone': + everyone_gml = await self.main.get_gml(self.guild_id) + + return await everyone_gml.shard_query( + session_id, ranges + ) + await self._init_check() # make sure this is a sane state @@ -335,22 +359,9 @@ class GuildMemberList: await self.unsub(session_id) return - # since this is a sane state AND - # trying to query, we automatically - # subscribe the state to this list - await self.sub(session_id) - - # TODO: subscribe shard to the 'everyone' member list - # and forward the query to that list - reply = { 'guild_id': str(self.guild_id), - - # TODO: everyone for channels without overrides - # channel_id for channels WITH overrides. - - 'id': 'everyone', - # 'id': str(self.channel_id), + 'id': list_id, 'groups': [ { @@ -386,22 +397,17 @@ class GuildMemberList: return list(self.state) async def dispatch(self, event: str, data: Any): - """The dispatch() method here, instead of being - about dispatching a single event to the subscribed - users and forgetting about it, is about storing - the actual member list information so that we - can generate the respective events to the users. + """Modify the member list and dispatch the respective + events to subscribed shards. GuildMemberList stores the current guilds' list - in its :attr:`GuildMemberList.member_list` attribute, + in its :attr:`GuildMemberList.list` attribute, with that attribute being modified via different calls to :meth:`GuildMemberList.dispatch` """ - if self.member_list is None: - # if the list is currently uninitialized, - # no subscribers actually happened, so - # we can safely drop the incoming event. + # if no subscribers, drop event + if not self.list: return @@ -436,6 +442,11 @@ class LazyGuildDispatcher(Dispatcher): channel_id ) + # if we don't find a guild, we just + # set it the same as the channel. + if not guild_id: + guild_id = channel_id + gml = GuildMemberList(guild_id, channel_id, self) self.state[channel_id] = gml self.guild_map[guild_id].append(channel_id) diff --git a/schema.sql b/schema.sql index 9f9bda9..11b2caf 100644 --- a/schema.sql +++ b/schema.sql @@ -331,7 +331,7 @@ CREATE TABLE IF NOT EXISTS channel_overwrites ( channel_id bigint REFERENCES channels (id) ON DELETE CASCADE, -- target_type = 0 -> use target_user - -- target_type = 1 -> user target_role + -- target_type = 1 -> use target_role -- discord already has overwrite.type = 'role' | 'member' -- so this allows us to be more compliant with the API target_type integer default null,