diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py index 4ad5ee7..7d40c07 100644 --- a/litecord/pubsub/dispatcher.py +++ b/litecord/pubsub/dispatcher.py @@ -128,3 +128,25 @@ class DispatcherWithState(Dispatcher): async def dispatch(self, key, *args): raise NotImplementedError + + +class DispatcherWithFlags(DispatcherWithState): + """Pub/Sub backend with both a state and a flags store.""" + + def __init__(self, main): + super().__init__(main) + + #: keep flags for subscribers, so for example + # a subscriber could drop all presence events at the + # pubsub level. see gateway's guild_subscriptions field for more + self.flags = defaultdict(dict) + + async def sub(self, key, identifier, flags=None): + """Subscribe a user to the guild.""" + await super().sub(key, identifier) + self.flags[key][identifier] = flags or {} + + async def unsub(self, key, identifier): + """Unsubscribe a user from the guild.""" + await super().unsub(key, identifier) + self.flags[key].pop(identifier) diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index fdfdf2c..20b3bd1 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -22,25 +22,17 @@ from collections import defaultdict from logbook import Logger -from .dispatcher import DispatcherWithState +from .dispatcher import DispatcherWithFlags from litecord.permissions import get_permissions log = Logger(__name__) -class GuildDispatcher(DispatcherWithState): +class GuildDispatcher(DispatcherWithFlags): """Guild backend for Pub/Sub""" KEY_TYPE = int VAL_TYPE = int - def __init__(self, main): - super().__init__(main) - - #: keep flags for subscribers, so for example - # a subscriber could drop all presence events at the - # pubsub level. see gateway's guild_subscriptions field for more - self.flags = defaultdict(dict) - async def _chan_action(self, action: str, guild_id: int, user_id: int, flags=None): """Send an action to all channels of the guild.""" @@ -80,17 +72,14 @@ class GuildDispatcher(DispatcherWithState): meth, chan_id) await method(chan_id, *args) - async def sub(self, guild_id: int, user_id: int, flags = None): + async def sub(self, guild_id: int, user_id: int, flags=None): """Subscribe a user to the guild.""" - await super().sub(guild_id, user_id) - self.flags[guild_id][user_id] = flags or {} - + await super().sub(guild_id, user_id, flags) await self._chan_action('sub', guild_id, user_id, flags) async def unsub(self, guild_id: int, user_id: int): """Unsubscribe a user from the guild.""" await super().unsub(guild_id, user_id) - self.flags[guild_id].pop(user_id) await self._chan_action('unsub', guild_id, user_id) async def dispatch_filter(self, guild_id: int, func,