diff --git a/litecord/dispatcher.py b/litecord/dispatcher.py index 8d38303..96f29d3 100644 --- a/litecord/dispatcher.py +++ b/litecord/dispatcher.py @@ -17,7 +17,7 @@ along with this program. If not, see . """ -from typing import List, Any +from typing import List, Any, Dict from logbook import Logger @@ -57,7 +57,7 @@ class EventDispatcher: 'lazy_guild': LazyGuildDispatcher(self), } - async def action(self, backend_str: str, action: str, key, identifier): + async def action(self, backend_str: str, action: str, key, identifier, *args): """Send an action regarding a key/identifier pair to a backend. Action is usually "sub" or "unsub". @@ -69,13 +69,24 @@ class EventDispatcher: key = backend.KEY_TYPE(key) identifier = backend.VAL_TYPE(identifier) - return await method(key, identifier) + return await method(key, identifier, *args) - async def subscribe(self, backend: str, key: Any, identifier: Any): + async def subscribe(self, backend: str, key: Any, identifier: Any, + flags: Dict[str, Any] = None): """Subscribe a single element to the given backend.""" + flags = flags or {} + log.debug('SUB backend={} key={} <= id={}', backend, key, identifier, backend) + # this is a hacky solution for backwards compatibility between backends + # that implement flags and backends that don't. + + # passing flags to backends that don't implement flags will + # cause errors as expected. + if flags: + return await self.action(backend, 'sub', key, identifier, flags) + return await self.action(backend, 'sub', key, identifier) async def unsubscribe(self, backend: str, key: Any, identifier: Any): @@ -93,24 +104,34 @@ class EventDispatcher: """Alias to unsubscribe().""" return await self.unsubscribe(backend, key, identifier) - async def sub_many(self, backend_str: str, identifier: Any, keys: list): + async def sub_many(self, backend_str: str, identifier: Any, + keys: list, flags: Dict[str, Any] = None): """Subscribe to multiple channels (all in a single backend) at a time. Usually used when connecting to the gateway and the client needs to subscribe to all their guids. """ + flags = flags or {} for key in keys: - await self.subscribe(backend_str, key, identifier) + await self.subscribe(backend_str, key, identifier, flags) async def mass_sub(self, identifier: Any, backends: List[tuple]): """Mass subscribe to many backends at once.""" - for backend_str, keys in backends: - log.debug('subscribing {} to {} keys in backend {}', - identifier, len(keys), backend_str) + for bcall in backends: + backend_str, keys = bcall[0], bcall[1] - await self.sub_many(backend_str, identifier, keys) + if len(bcall) == 2: + flags = {} + elif len(bcall) == 3: + # we have flags + flags = bcall[2] + + log.debug('subscribing {} to {} keys in backend {}, flags: {}', + identifier, len(keys), backend_str, flags) + + await self.sub_many(backend_str, identifier, keys, flags) async def dispatch(self, backend_str: str, key: Any, *args, **kwargs): """Dispatch an event to the backend. diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 365163c..d68f2d6 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -421,7 +421,7 @@ class GatewayWebsocket: return list(filtered) - async def subscribe_all(self): + async def subscribe_all(self, guild_subscriptions: bool): """Subscribe to all guilds, DM channels, and friends. Note: subscribing to channels is already handled @@ -437,15 +437,23 @@ class GatewayWebsocket: # fetch all group dms the user is a member of. gdm_ids = await self.user_storage.get_gdms_internal(user_id) - log.info('subscribing to {} guilds', len(guild_ids)) - log.info('subscribing to {} dms', len(dm_ids)) - log.info('subscribing to {} group dms', len(gdm_ids)) + log.info('subscribing to {} guilds {} dms {} gdms', + len(guild_ids), len(dm_ids), len(gdm_ids)) - await self.ext.dispatcher.mass_sub(user_id, [ - ('guild', guild_ids), + # guild_subscriptions: + # enables dispatching of guild subscription events + # (presence and typing events) + + # we enable processing of guild_subscriptions by adding flags + # when subscribing to the given backend. those are optional. + channels_to_sub = [ + ('guild', guild_ids, + {'presence': guild_subscriptions, 'typing': guild_subscriptions}), ('channel', dm_ids), - ('channel', gdm_ids) - ]) + ('channel', gdm_ids), + ] + + await self.ext.dispatcher.mass_sub(user_id, channels_to_sub) if not self.state.bot: # subscribe to all friends @@ -573,7 +581,7 @@ class GatewayWebsocket: self.ext.state_manager.insert(self.state) await self.update_status(presence) - await self.subscribe_all() + await self.subscribe_all(data.get('guild_subscriptions', True)) await self.dispatch_ready() async def handle_3(self, payload: Dict[str, Any]): diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py index 15e4010..443d1e3 100644 --- a/litecord/pubsub/channel.py +++ b/litecord/pubsub/channel.py @@ -21,7 +21,7 @@ from typing import Any, List from logbook import Logger -from .dispatcher import DispatcherWithState +from .dispatcher import DispatcherWithFlags from litecord.enums import ChannelType from litecord.utils import index_by_func @@ -48,7 +48,7 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict: return data -class ChannelDispatcher(DispatcherWithState): +class ChannelDispatcher(DispatcherWithFlags): """Main channel Pub/Sub logic.""" KEY_TYPE = int VAL_TYPE = int @@ -84,6 +84,11 @@ class ChannelDispatcher(DispatcherWithState): await self.unsub(channel_id, user_id) continue + # skip typing events for users that don't want it + if event.startswith('TYPING_') and \ + not self.flags_get(channel_id, user_id, 'typing', True): + continue + cur_sess = [] if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \ diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py index dd03ef2..493162a 100644 --- a/litecord/pubsub/dispatcher.py +++ b/litecord/pubsub/dispatcher.py @@ -89,7 +89,7 @@ class Dispatcher: try: await state.ws.dispatch(event, data) res.append(state.session_id) - except: + except Exception: log.exception('error while dispatching') return res @@ -128,3 +128,32 @@ class DispatcherWithState(Dispatcher): async def dispatch(self, key, *args): raise NotImplementedError + + +class DispatcherWithFlags(DispatcherWithState): + """Pub/Sub backend with both a state and a flags store.""" + + def __init__(self, main): + super().__init__(main) + + #: keep flags for subscribers, so for example + # a subscriber could drop all presence events at the + # pubsub level. see gateway's guild_subscriptions field for more + self.flags = defaultdict(dict) + + async def sub(self, key, identifier, flags=None): + """Subscribe a user to the guild.""" + await super().sub(key, identifier) + self.flags[key][identifier] = flags or {} + + async def unsub(self, key, identifier): + """Unsubscribe a user from the guild.""" + await super().unsub(key, identifier) + self.flags[key].pop(identifier) + + def flags_get(self, key, identifier, field: str, default): + """Get a single field from the flags store.""" + # yes, i know its simply an indirection from the main flags store, + # but i'd rather have this than change every call if i ever change + # the structure of the flags store. + return self.flags[key][identifier].get(field, default) diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index 5a419e2..462fb63 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -21,20 +21,21 @@ from typing import Any from logbook import Logger -from .dispatcher import DispatcherWithState +from .dispatcher import DispatcherWithFlags from litecord.permissions import get_permissions log = Logger(__name__) -class GuildDispatcher(DispatcherWithState): +class GuildDispatcher(DispatcherWithFlags): """Guild backend for Pub/Sub""" KEY_TYPE = int VAL_TYPE = int async def _chan_action(self, action: str, - guild_id: int, user_id: int): + guild_id: int, user_id: int, flags=None): """Send an action to all channels of the guild.""" + flags = flags or {} chan_ids = await self.app.storage.get_channel_ids(guild_id) for chan_id in chan_ids: @@ -53,8 +54,22 @@ class GuildDispatcher(DispatcherWithState): log.debug('sending raw action {!r} to chan={}', action, chan_id) + # for now, only sub() has support for flags. + # it is an idea to have flags support for other actions + args = [] + if action == 'sub': + chanflags = dict(flags) + + # channels don't need presence flags + try: + chanflags.pop('presence') + except KeyError: + pass + + args.append(chanflags) + await self.main_dispatcher.action( - 'channel', action, chan_id, user_id + 'channel', action, chan_id, user_id, *args ) async def _chan_call(self, meth: str, guild_id: int, *args): @@ -70,10 +85,10 @@ class GuildDispatcher(DispatcherWithState): meth, chan_id) await method(chan_id, *args) - async def sub(self, guild_id: int, user_id: int): + async def sub(self, guild_id: int, user_id: int, flags=None): """Subscribe a user to the guild.""" - await super().sub(guild_id, user_id) - await self._chan_action('sub', guild_id, user_id) + await super().sub(guild_id, user_id, flags) + await self._chan_action('sub', guild_id, user_id, flags) async def unsub(self, guild_id: int, user_id: int): """Unsubscribe a user from the guild.""" @@ -101,6 +116,15 @@ class GuildDispatcher(DispatcherWithState): await self.unsub(guild_id, user_id) continue + # skip the given subscriber if event starts with PRESENCE_ + # and the flags say they don't want it. + + # note that this does not equate to any unsubscription + # of the channel. + if event.startswith('PRESENCE_') and \ + not self.flags_get(guild_id, user_id, 'presence', True): + continue + # filter the ones that matter states = list(filter( lambda state: func(state.session_id), states @@ -108,6 +132,7 @@ class GuildDispatcher(DispatcherWithState): cur_sess = await self._dispatch_states( states, event, data) + sessions.extend(cur_sess) dispatched += len(cur_sess)