diff --git a/litecord/dispatcher.py b/litecord/dispatcher.py index b2d0ab2..1067df6 100644 --- a/litecord/dispatcher.py +++ b/litecord/dispatcher.py @@ -4,7 +4,7 @@ from typing import Any from logbook import Logger from .pubsub import GuildDispatcher, MemberDispatcher, \ - UserDispatcher + UserDispatcher, ChannelDispatcher log = Logger(__name__) @@ -18,6 +18,7 @@ class EventDispatcher: self.backends = { 'guild': GuildDispatcher(self), 'member': MemberDispatcher(self), + 'channel': ChannelDispatcher(self), 'user': UserDispatcher(self), # TODO: channel, friends @@ -35,10 +36,14 @@ class EventDispatcher: async def subscribe(self, backend: str, key: Any, identifier: Any): """Subscribe a single element to the given backend.""" + log.debug('SUB bacjend={} key={} <= id={}', + backend, key, identifier, backend) return await self.action(backend, 'sub', key, identifier) async def unsubscribe(self, backend: str, key: Any, identifier: Any): """Unsubscribe an element from the given backend.""" + log.debug('UNSUB bacjend={} key={} => id={}', + backend, key, identifier, backend) return await self.action(backend, 'unsub', key, identifier) async def dispatch(self, backend_str: str, key: Any, *args, **kwargs): diff --git a/litecord/pubsub/__init__.py b/litecord/pubsub/__init__.py index 9586ed9..2021ca8 100644 --- a/litecord/pubsub/__init__.py +++ b/litecord/pubsub/__init__.py @@ -1,3 +1,7 @@ from .guild import GuildDispatcher from .member import MemberDispatcher from .user import UserDispatcher +from .channel import ChannelDispatcher + +__all__ = ['GuildDispatcher', 'MemberDispatcher', + 'UserDispatcher', 'ChannelDispatcher'] diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py new file mode 100644 index 0000000..4b722a4 --- /dev/null +++ b/litecord/pubsub/channel.py @@ -0,0 +1,51 @@ +from typing import Any +from collections import defaultdict + +from logbook import Logger + +from .dispatcher import Dispatcher + +log = Logger(__name__) + + +class ChannelDispatcher(Dispatcher): + """Main channel Pub/Sub logic.""" + def __init__(self, main): + super().__init__(main) + + self.channels = defaultdict(set) + + async def sub(self, channel_id: int, user_id: int): + self.channels[channel_id].add(user_id) + + async def unsub(self, channel_id: int, user_id: int): + self.channels[channel_id].discard(user_id) + + async def reset(self, channel_id: int): + self.channels[channel_id] = set() + + async def remove(self, channel_id: int): + try: + self.channels.pop(channel_id) + except KeyError: + pass + + async def dispatch(self, channel_id, + event: str, data: Any): + user_ids = self.channels[channel_id] + dispatched = 0 + + for user_id in set(user_ids): + guild_id = await self.app.storage.guild_from_channel(channel_id) + + if guild_id: + states = self.sm.fetch_states(user_id, guild_id) + else: + # TODO: maybe a fetch_states with guild_id 0 + # to get the shards with id 0 AND the single shards? + states = self.sm.user_states(user_id) + + dispatched += await self._dispatch_states(states, event, data) + + log.info('Dispatched chan={} {!r} to {} states', + channel_id, event, dispatched) diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py index 59cf08c..f15f0fb 100644 --- a/litecord/pubsub/dispatcher.py +++ b/litecord/pubsub/dispatcher.py @@ -5,8 +5,8 @@ log = Logger(__name__) class Dispatcher: """Main dispatcher class.""" - KEY_TYPE = lambda x: x - VAL_TYPE = lambda x: x + KEY_TYPE = lambda _, x: x + VAL_TYPE = lambda _, x: x def __init__(self, main): self.main_dispatcher = main diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index e8d624f..df97bdb 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -17,14 +17,44 @@ class GuildDispatcher(Dispatcher): super().__init__(main) self.guild_buckets = defaultdict(set) + async def _chan_action(self, action: str, guild_id: int, user_id: int): + chan_ids = await self.app.storage.get_channel_ids(guild_id) + + # TODO: check READ_MESSAGE permissions for the user + + for chan_id in chan_ids: + log.debug('sending raw action {!r} to chan={}', + action, chan_id) + + await self.main_dispatcher.action( + 'channel', action, chan_id, user_id + ) + + async def _chan_call(self, meth: str, guild_id: int, *args): + chan_ids = await self.app.storage.get_channel_ids(guild_id) + chan_dispatcher = self.main_dispatcher.backends['channel'] + method = getattr(chan_dispatcher, meth) + + for chan_id in chan_ids: + log.debug('calling {} to chan={}', + meth, chan_id) + await method(chan_id, *args) + async def sub(self, guild_id: int, user_id: int): self.guild_buckets[guild_id].add(user_id) + # when subbing a user to the guild, we should sub them + # to every channel they have access to, in the guild. + + await self._chan_action('sub', guild_id, user_id) + async def unsub(self, guild_id: int, user_id: int): self.guild_buckets[guild_id].discard(user_id) + await self._chan_action('unsub', guild_id, user_id) async def reset(self, guild_id: int): self.guild_buckets[guild_id] = set() + await self._chan_call(guild_id, 'reset') async def remove(self, guild_id: int): try: @@ -32,6 +62,8 @@ class GuildDispatcher(Dispatcher): except KeyError: pass + await self._chan_call(guild_id, 'remove') + async def dispatch(self, guild_id: int, event_name: str, event_payload: Any): user_ids = self.guild_buckets[guild_id] @@ -46,7 +78,7 @@ class GuildDispatcher(Dispatcher): if not states: # user is actually disconnected, - # so we should just unsub it + # so we should just unsub them await self.unsub(guild_id, user_id) continue diff --git a/litecord/storage.py b/litecord/storage.py index 7d9542d..cb89c3c 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -240,6 +240,7 @@ class Storage: return members async def chan_last_message(self, channel_id: int): + """Get the last message ID in a channel.""" return await self.db.fetchval(""" SELECT MAX(id) FROM messages @@ -247,6 +248,11 @@ class Storage: """, channel_id) async def chan_last_message_str(self, channel_id: int) -> int: + """Get the last message ID but in a string. + + Converts to None (not the string "None") when + no last message ID is found. + """ last_msg = await self.chan_last_message(channel_id) return str(last_msg) if last_msg is not None else None @@ -268,7 +274,8 @@ class Storage: 'topic': topic, 'last_message_id': last_msg, }} - elif chan_type == ChannelType.GUILD_VOICE: + + if chan_type == ChannelType.GUILD_VOICE: vrow = await self.db.fetchval(""" SELECT bitrate, user_limit FROM guild_voice_channels WHERE id = $1