mirror of https://gitlab.com/litecord/litecord.git
pubsub: add ChannelDispatcher
- pubsub: call ChannelDispatcher from GuildDispatcher when subbing a
user
This commit is contained in:
parent
5372292f0d
commit
37d8114ae2
|
|
@ -4,7 +4,7 @@ from typing import Any
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .pubsub import GuildDispatcher, MemberDispatcher, \
|
from .pubsub import GuildDispatcher, MemberDispatcher, \
|
||||||
UserDispatcher
|
UserDispatcher, ChannelDispatcher
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
@ -18,6 +18,7 @@ class EventDispatcher:
|
||||||
self.backends = {
|
self.backends = {
|
||||||
'guild': GuildDispatcher(self),
|
'guild': GuildDispatcher(self),
|
||||||
'member': MemberDispatcher(self),
|
'member': MemberDispatcher(self),
|
||||||
|
'channel': ChannelDispatcher(self),
|
||||||
'user': UserDispatcher(self),
|
'user': UserDispatcher(self),
|
||||||
|
|
||||||
# TODO: channel, friends
|
# TODO: channel, friends
|
||||||
|
|
@ -35,10 +36,14 @@ class EventDispatcher:
|
||||||
|
|
||||||
async def subscribe(self, backend: str, key: Any, identifier: Any):
|
async def subscribe(self, backend: str, key: Any, identifier: Any):
|
||||||
"""Subscribe a single element to the given backend."""
|
"""Subscribe a single element to the given backend."""
|
||||||
|
log.debug('SUB bacjend={} key={} <= id={}',
|
||||||
|
backend, key, identifier, backend)
|
||||||
return await self.action(backend, 'sub', key, identifier)
|
return await self.action(backend, 'sub', key, identifier)
|
||||||
|
|
||||||
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
|
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
|
||||||
"""Unsubscribe an element from the given backend."""
|
"""Unsubscribe an element from the given backend."""
|
||||||
|
log.debug('UNSUB bacjend={} key={} => id={}',
|
||||||
|
backend, key, identifier, backend)
|
||||||
return await self.action(backend, 'unsub', key, identifier)
|
return await self.action(backend, 'unsub', key, identifier)
|
||||||
|
|
||||||
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
|
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,7 @@
|
||||||
from .guild import GuildDispatcher
|
from .guild import GuildDispatcher
|
||||||
from .member import MemberDispatcher
|
from .member import MemberDispatcher
|
||||||
from .user import UserDispatcher
|
from .user import UserDispatcher
|
||||||
|
from .channel import ChannelDispatcher
|
||||||
|
|
||||||
|
__all__ = ['GuildDispatcher', 'MemberDispatcher',
|
||||||
|
'UserDispatcher', 'ChannelDispatcher']
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
from typing import Any
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from logbook import Logger
|
||||||
|
|
||||||
|
from .dispatcher import Dispatcher
|
||||||
|
|
||||||
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelDispatcher(Dispatcher):
|
||||||
|
"""Main channel Pub/Sub logic."""
|
||||||
|
def __init__(self, main):
|
||||||
|
super().__init__(main)
|
||||||
|
|
||||||
|
self.channels = defaultdict(set)
|
||||||
|
|
||||||
|
async def sub(self, channel_id: int, user_id: int):
|
||||||
|
self.channels[channel_id].add(user_id)
|
||||||
|
|
||||||
|
async def unsub(self, channel_id: int, user_id: int):
|
||||||
|
self.channels[channel_id].discard(user_id)
|
||||||
|
|
||||||
|
async def reset(self, channel_id: int):
|
||||||
|
self.channels[channel_id] = set()
|
||||||
|
|
||||||
|
async def remove(self, channel_id: int):
|
||||||
|
try:
|
||||||
|
self.channels.pop(channel_id)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def dispatch(self, channel_id,
|
||||||
|
event: str, data: Any):
|
||||||
|
user_ids = self.channels[channel_id]
|
||||||
|
dispatched = 0
|
||||||
|
|
||||||
|
for user_id in set(user_ids):
|
||||||
|
guild_id = await self.app.storage.guild_from_channel(channel_id)
|
||||||
|
|
||||||
|
if guild_id:
|
||||||
|
states = self.sm.fetch_states(user_id, guild_id)
|
||||||
|
else:
|
||||||
|
# TODO: maybe a fetch_states with guild_id 0
|
||||||
|
# to get the shards with id 0 AND the single shards?
|
||||||
|
states = self.sm.user_states(user_id)
|
||||||
|
|
||||||
|
dispatched += await self._dispatch_states(states, event, data)
|
||||||
|
|
||||||
|
log.info('Dispatched chan={} {!r} to {} states',
|
||||||
|
channel_id, event, dispatched)
|
||||||
|
|
@ -5,8 +5,8 @@ log = Logger(__name__)
|
||||||
|
|
||||||
class Dispatcher:
|
class Dispatcher:
|
||||||
"""Main dispatcher class."""
|
"""Main dispatcher class."""
|
||||||
KEY_TYPE = lambda x: x
|
KEY_TYPE = lambda _, x: x
|
||||||
VAL_TYPE = lambda x: x
|
VAL_TYPE = lambda _, x: x
|
||||||
|
|
||||||
def __init__(self, main):
|
def __init__(self, main):
|
||||||
self.main_dispatcher = main
|
self.main_dispatcher = main
|
||||||
|
|
|
||||||
|
|
@ -17,14 +17,44 @@ class GuildDispatcher(Dispatcher):
|
||||||
super().__init__(main)
|
super().__init__(main)
|
||||||
self.guild_buckets = defaultdict(set)
|
self.guild_buckets = defaultdict(set)
|
||||||
|
|
||||||
|
async def _chan_action(self, action: str, guild_id: int, user_id: int):
|
||||||
|
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||||
|
|
||||||
|
# TODO: check READ_MESSAGE permissions for the user
|
||||||
|
|
||||||
|
for chan_id in chan_ids:
|
||||||
|
log.debug('sending raw action {!r} to chan={}',
|
||||||
|
action, chan_id)
|
||||||
|
|
||||||
|
await self.main_dispatcher.action(
|
||||||
|
'channel', action, chan_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _chan_call(self, meth: str, guild_id: int, *args):
|
||||||
|
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||||
|
chan_dispatcher = self.main_dispatcher.backends['channel']
|
||||||
|
method = getattr(chan_dispatcher, meth)
|
||||||
|
|
||||||
|
for chan_id in chan_ids:
|
||||||
|
log.debug('calling {} to chan={}',
|
||||||
|
meth, chan_id)
|
||||||
|
await method(chan_id, *args)
|
||||||
|
|
||||||
async def sub(self, guild_id: int, user_id: int):
|
async def sub(self, guild_id: int, user_id: int):
|
||||||
self.guild_buckets[guild_id].add(user_id)
|
self.guild_buckets[guild_id].add(user_id)
|
||||||
|
|
||||||
|
# when subbing a user to the guild, we should sub them
|
||||||
|
# to every channel they have access to, in the guild.
|
||||||
|
|
||||||
|
await self._chan_action('sub', guild_id, user_id)
|
||||||
|
|
||||||
async def unsub(self, guild_id: int, user_id: int):
|
async def unsub(self, guild_id: int, user_id: int):
|
||||||
self.guild_buckets[guild_id].discard(user_id)
|
self.guild_buckets[guild_id].discard(user_id)
|
||||||
|
await self._chan_action('unsub', guild_id, user_id)
|
||||||
|
|
||||||
async def reset(self, guild_id: int):
|
async def reset(self, guild_id: int):
|
||||||
self.guild_buckets[guild_id] = set()
|
self.guild_buckets[guild_id] = set()
|
||||||
|
await self._chan_call(guild_id, 'reset')
|
||||||
|
|
||||||
async def remove(self, guild_id: int):
|
async def remove(self, guild_id: int):
|
||||||
try:
|
try:
|
||||||
|
|
@ -32,6 +62,8 @@ class GuildDispatcher(Dispatcher):
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
await self._chan_call(guild_id, 'remove')
|
||||||
|
|
||||||
async def dispatch(self, guild_id: int,
|
async def dispatch(self, guild_id: int,
|
||||||
event_name: str, event_payload: Any):
|
event_name: str, event_payload: Any):
|
||||||
user_ids = self.guild_buckets[guild_id]
|
user_ids = self.guild_buckets[guild_id]
|
||||||
|
|
@ -46,7 +78,7 @@ class GuildDispatcher(Dispatcher):
|
||||||
|
|
||||||
if not states:
|
if not states:
|
||||||
# user is actually disconnected,
|
# user is actually disconnected,
|
||||||
# so we should just unsub it
|
# so we should just unsub them
|
||||||
await self.unsub(guild_id, user_id)
|
await self.unsub(guild_id, user_id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -240,6 +240,7 @@ class Storage:
|
||||||
return members
|
return members
|
||||||
|
|
||||||
async def chan_last_message(self, channel_id: int):
|
async def chan_last_message(self, channel_id: int):
|
||||||
|
"""Get the last message ID in a channel."""
|
||||||
return await self.db.fetchval("""
|
return await self.db.fetchval("""
|
||||||
SELECT MAX(id)
|
SELECT MAX(id)
|
||||||
FROM messages
|
FROM messages
|
||||||
|
|
@ -247,6 +248,11 @@ class Storage:
|
||||||
""", channel_id)
|
""", channel_id)
|
||||||
|
|
||||||
async def chan_last_message_str(self, channel_id: int) -> int:
|
async def chan_last_message_str(self, channel_id: int) -> int:
|
||||||
|
"""Get the last message ID but in a string.
|
||||||
|
|
||||||
|
Converts to None (not the string "None") when
|
||||||
|
no last message ID is found.
|
||||||
|
"""
|
||||||
last_msg = await self.chan_last_message(channel_id)
|
last_msg = await self.chan_last_message(channel_id)
|
||||||
return str(last_msg) if last_msg is not None else None
|
return str(last_msg) if last_msg is not None else None
|
||||||
|
|
||||||
|
|
@ -268,7 +274,8 @@ class Storage:
|
||||||
'topic': topic,
|
'topic': topic,
|
||||||
'last_message_id': last_msg,
|
'last_message_id': last_msg,
|
||||||
}}
|
}}
|
||||||
elif chan_type == ChannelType.GUILD_VOICE:
|
|
||||||
|
if chan_type == ChannelType.GUILD_VOICE:
|
||||||
vrow = await self.db.fetchval("""
|
vrow = await self.db.fetchval("""
|
||||||
SELECT bitrate, user_limit FROM guild_voice_channels
|
SELECT bitrate, user_limit FROM guild_voice_channels
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue