mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'impl/gw-guild-subscriptions' into 'master'
Gateway guild subscriptions See merge request litecord/litecord!40
This commit is contained in:
commit
f0dde07418
|
|
@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Any
|
from typing import List, Any, Dict
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
|
|
@ -57,7 +57,7 @@ class EventDispatcher:
|
||||||
'lazy_guild': LazyGuildDispatcher(self),
|
'lazy_guild': LazyGuildDispatcher(self),
|
||||||
}
|
}
|
||||||
|
|
||||||
async def action(self, backend_str: str, action: str, key, identifier):
|
async def action(self, backend_str: str, action: str, key, identifier, *args):
|
||||||
"""Send an action regarding a key/identifier pair to a backend.
|
"""Send an action regarding a key/identifier pair to a backend.
|
||||||
|
|
||||||
Action is usually "sub" or "unsub".
|
Action is usually "sub" or "unsub".
|
||||||
|
|
@ -69,13 +69,24 @@ class EventDispatcher:
|
||||||
key = backend.KEY_TYPE(key)
|
key = backend.KEY_TYPE(key)
|
||||||
identifier = backend.VAL_TYPE(identifier)
|
identifier = backend.VAL_TYPE(identifier)
|
||||||
|
|
||||||
return await method(key, identifier)
|
return await method(key, identifier, *args)
|
||||||
|
|
||||||
async def subscribe(self, backend: str, key: Any, identifier: Any):
|
async def subscribe(self, backend: str, key: Any, identifier: Any,
|
||||||
|
flags: Dict[str, Any] = None):
|
||||||
"""Subscribe a single element to the given backend."""
|
"""Subscribe a single element to the given backend."""
|
||||||
|
flags = flags or {}
|
||||||
|
|
||||||
log.debug('SUB backend={} key={} <= id={}',
|
log.debug('SUB backend={} key={} <= id={}',
|
||||||
backend, key, identifier, backend)
|
backend, key, identifier, backend)
|
||||||
|
|
||||||
|
# this is a hacky solution for backwards compatibility between backends
|
||||||
|
# that implement flags and backends that don't.
|
||||||
|
|
||||||
|
# passing flags to backends that don't implement flags will
|
||||||
|
# cause errors as expected.
|
||||||
|
if flags:
|
||||||
|
return await self.action(backend, 'sub', key, identifier, flags)
|
||||||
|
|
||||||
return await self.action(backend, 'sub', key, identifier)
|
return await self.action(backend, 'sub', key, identifier)
|
||||||
|
|
||||||
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
|
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
|
||||||
|
|
@ -93,24 +104,34 @@ class EventDispatcher:
|
||||||
"""Alias to unsubscribe()."""
|
"""Alias to unsubscribe()."""
|
||||||
return await self.unsubscribe(backend, key, identifier)
|
return await self.unsubscribe(backend, key, identifier)
|
||||||
|
|
||||||
async def sub_many(self, backend_str: str, identifier: Any, keys: list):
|
async def sub_many(self, backend_str: str, identifier: Any,
|
||||||
|
keys: list, flags: Dict[str, Any] = None):
|
||||||
"""Subscribe to multiple channels (all in a single backend)
|
"""Subscribe to multiple channels (all in a single backend)
|
||||||
at a time.
|
at a time.
|
||||||
|
|
||||||
Usually used when connecting to the gateway and the client
|
Usually used when connecting to the gateway and the client
|
||||||
needs to subscribe to all their guids.
|
needs to subscribe to all their guids.
|
||||||
"""
|
"""
|
||||||
|
flags = flags or {}
|
||||||
for key in keys:
|
for key in keys:
|
||||||
await self.subscribe(backend_str, key, identifier)
|
await self.subscribe(backend_str, key, identifier, flags)
|
||||||
|
|
||||||
async def mass_sub(self, identifier: Any,
|
async def mass_sub(self, identifier: Any,
|
||||||
backends: List[tuple]):
|
backends: List[tuple]):
|
||||||
"""Mass subscribe to many backends at once."""
|
"""Mass subscribe to many backends at once."""
|
||||||
for backend_str, keys in backends:
|
for bcall in backends:
|
||||||
log.debug('subscribing {} to {} keys in backend {}',
|
backend_str, keys = bcall[0], bcall[1]
|
||||||
identifier, len(keys), backend_str)
|
|
||||||
|
|
||||||
await self.sub_many(backend_str, identifier, keys)
|
if len(bcall) == 2:
|
||||||
|
flags = {}
|
||||||
|
elif len(bcall) == 3:
|
||||||
|
# we have flags
|
||||||
|
flags = bcall[2]
|
||||||
|
|
||||||
|
log.debug('subscribing {} to {} keys in backend {}, flags: {}',
|
||||||
|
identifier, len(keys), backend_str, flags)
|
||||||
|
|
||||||
|
await self.sub_many(backend_str, identifier, keys, flags)
|
||||||
|
|
||||||
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
|
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
|
||||||
"""Dispatch an event to the backend.
|
"""Dispatch an event to the backend.
|
||||||
|
|
|
||||||
|
|
@ -421,7 +421,7 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
return list(filtered)
|
return list(filtered)
|
||||||
|
|
||||||
async def subscribe_all(self):
|
async def subscribe_all(self, guild_subscriptions: bool):
|
||||||
"""Subscribe to all guilds, DM channels, and friends.
|
"""Subscribe to all guilds, DM channels, and friends.
|
||||||
|
|
||||||
Note: subscribing to channels is already handled
|
Note: subscribing to channels is already handled
|
||||||
|
|
@ -437,15 +437,23 @@ class GatewayWebsocket:
|
||||||
# fetch all group dms the user is a member of.
|
# fetch all group dms the user is a member of.
|
||||||
gdm_ids = await self.user_storage.get_gdms_internal(user_id)
|
gdm_ids = await self.user_storage.get_gdms_internal(user_id)
|
||||||
|
|
||||||
log.info('subscribing to {} guilds', len(guild_ids))
|
log.info('subscribing to {} guilds {} dms {} gdms',
|
||||||
log.info('subscribing to {} dms', len(dm_ids))
|
len(guild_ids), len(dm_ids), len(gdm_ids))
|
||||||
log.info('subscribing to {} group dms', len(gdm_ids))
|
|
||||||
|
|
||||||
await self.ext.dispatcher.mass_sub(user_id, [
|
# guild_subscriptions:
|
||||||
('guild', guild_ids),
|
# enables dispatching of guild subscription events
|
||||||
|
# (presence and typing events)
|
||||||
|
|
||||||
|
# we enable processing of guild_subscriptions by adding flags
|
||||||
|
# when subscribing to the given backend. those are optional.
|
||||||
|
channels_to_sub = [
|
||||||
|
('guild', guild_ids,
|
||||||
|
{'presence': guild_subscriptions, 'typing': guild_subscriptions}),
|
||||||
('channel', dm_ids),
|
('channel', dm_ids),
|
||||||
('channel', gdm_ids)
|
('channel', gdm_ids),
|
||||||
])
|
]
|
||||||
|
|
||||||
|
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
|
||||||
|
|
||||||
if not self.state.bot:
|
if not self.state.bot:
|
||||||
# subscribe to all friends
|
# subscribe to all friends
|
||||||
|
|
@ -573,7 +581,7 @@ class GatewayWebsocket:
|
||||||
self.ext.state_manager.insert(self.state)
|
self.ext.state_manager.insert(self.state)
|
||||||
|
|
||||||
await self.update_status(presence)
|
await self.update_status(presence)
|
||||||
await self.subscribe_all()
|
await self.subscribe_all(data.get('guild_subscriptions', True))
|
||||||
await self.dispatch_ready()
|
await self.dispatch_ready()
|
||||||
|
|
||||||
async def handle_3(self, payload: Dict[str, Any]):
|
async def handle_3(self, payload: Dict[str, Any]):
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from typing import Any, List
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .dispatcher import DispatcherWithState
|
from .dispatcher import DispatcherWithFlags
|
||||||
from litecord.enums import ChannelType
|
from litecord.enums import ChannelType
|
||||||
from litecord.utils import index_by_func
|
from litecord.utils import index_by_func
|
||||||
|
|
||||||
|
|
@ -48,7 +48,7 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ChannelDispatcher(DispatcherWithState):
|
class ChannelDispatcher(DispatcherWithFlags):
|
||||||
"""Main channel Pub/Sub logic."""
|
"""Main channel Pub/Sub logic."""
|
||||||
KEY_TYPE = int
|
KEY_TYPE = int
|
||||||
VAL_TYPE = int
|
VAL_TYPE = int
|
||||||
|
|
@ -84,6 +84,11 @@ class ChannelDispatcher(DispatcherWithState):
|
||||||
await self.unsub(channel_id, user_id)
|
await self.unsub(channel_id, user_id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# skip typing events for users that don't want it
|
||||||
|
if event.startswith('TYPING_') and \
|
||||||
|
not self.flags_get(channel_id, user_id, 'typing', True):
|
||||||
|
continue
|
||||||
|
|
||||||
cur_sess = []
|
cur_sess = []
|
||||||
|
|
||||||
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
|
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ class Dispatcher:
|
||||||
try:
|
try:
|
||||||
await state.ws.dispatch(event, data)
|
await state.ws.dispatch(event, data)
|
||||||
res.append(state.session_id)
|
res.append(state.session_id)
|
||||||
except:
|
except Exception:
|
||||||
log.exception('error while dispatching')
|
log.exception('error while dispatching')
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
@ -128,3 +128,32 @@ class DispatcherWithState(Dispatcher):
|
||||||
|
|
||||||
async def dispatch(self, key, *args):
|
async def dispatch(self, key, *args):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class DispatcherWithFlags(DispatcherWithState):
|
||||||
|
"""Pub/Sub backend with both a state and a flags store."""
|
||||||
|
|
||||||
|
def __init__(self, main):
|
||||||
|
super().__init__(main)
|
||||||
|
|
||||||
|
#: keep flags for subscribers, so for example
|
||||||
|
# a subscriber could drop all presence events at the
|
||||||
|
# pubsub level. see gateway's guild_subscriptions field for more
|
||||||
|
self.flags = defaultdict(dict)
|
||||||
|
|
||||||
|
async def sub(self, key, identifier, flags=None):
|
||||||
|
"""Subscribe a user to the guild."""
|
||||||
|
await super().sub(key, identifier)
|
||||||
|
self.flags[key][identifier] = flags or {}
|
||||||
|
|
||||||
|
async def unsub(self, key, identifier):
|
||||||
|
"""Unsubscribe a user from the guild."""
|
||||||
|
await super().unsub(key, identifier)
|
||||||
|
self.flags[key].pop(identifier)
|
||||||
|
|
||||||
|
def flags_get(self, key, identifier, field: str, default):
|
||||||
|
"""Get a single field from the flags store."""
|
||||||
|
# yes, i know its simply an indirection from the main flags store,
|
||||||
|
# but i'd rather have this than change every call if i ever change
|
||||||
|
# the structure of the flags store.
|
||||||
|
return self.flags[key][identifier].get(field, default)
|
||||||
|
|
|
||||||
|
|
@ -21,20 +21,21 @@ from typing import Any
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .dispatcher import DispatcherWithState
|
from .dispatcher import DispatcherWithFlags
|
||||||
from litecord.permissions import get_permissions
|
from litecord.permissions import get_permissions
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class GuildDispatcher(DispatcherWithState):
|
class GuildDispatcher(DispatcherWithFlags):
|
||||||
"""Guild backend for Pub/Sub"""
|
"""Guild backend for Pub/Sub"""
|
||||||
KEY_TYPE = int
|
KEY_TYPE = int
|
||||||
VAL_TYPE = int
|
VAL_TYPE = int
|
||||||
|
|
||||||
async def _chan_action(self, action: str,
|
async def _chan_action(self, action: str,
|
||||||
guild_id: int, user_id: int):
|
guild_id: int, user_id: int, flags=None):
|
||||||
"""Send an action to all channels of the guild."""
|
"""Send an action to all channels of the guild."""
|
||||||
|
flags = flags or {}
|
||||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||||
|
|
||||||
for chan_id in chan_ids:
|
for chan_id in chan_ids:
|
||||||
|
|
@ -53,8 +54,22 @@ class GuildDispatcher(DispatcherWithState):
|
||||||
log.debug('sending raw action {!r} to chan={}',
|
log.debug('sending raw action {!r} to chan={}',
|
||||||
action, chan_id)
|
action, chan_id)
|
||||||
|
|
||||||
|
# for now, only sub() has support for flags.
|
||||||
|
# it is an idea to have flags support for other actions
|
||||||
|
args = []
|
||||||
|
if action == 'sub':
|
||||||
|
chanflags = dict(flags)
|
||||||
|
|
||||||
|
# channels don't need presence flags
|
||||||
|
try:
|
||||||
|
chanflags.pop('presence')
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
args.append(chanflags)
|
||||||
|
|
||||||
await self.main_dispatcher.action(
|
await self.main_dispatcher.action(
|
||||||
'channel', action, chan_id, user_id
|
'channel', action, chan_id, user_id, *args
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _chan_call(self, meth: str, guild_id: int, *args):
|
async def _chan_call(self, meth: str, guild_id: int, *args):
|
||||||
|
|
@ -70,10 +85,10 @@ class GuildDispatcher(DispatcherWithState):
|
||||||
meth, chan_id)
|
meth, chan_id)
|
||||||
await method(chan_id, *args)
|
await method(chan_id, *args)
|
||||||
|
|
||||||
async def sub(self, guild_id: int, user_id: int):
|
async def sub(self, guild_id: int, user_id: int, flags=None):
|
||||||
"""Subscribe a user to the guild."""
|
"""Subscribe a user to the guild."""
|
||||||
await super().sub(guild_id, user_id)
|
await super().sub(guild_id, user_id, flags)
|
||||||
await self._chan_action('sub', guild_id, user_id)
|
await self._chan_action('sub', guild_id, user_id, flags)
|
||||||
|
|
||||||
async def unsub(self, guild_id: int, user_id: int):
|
async def unsub(self, guild_id: int, user_id: int):
|
||||||
"""Unsubscribe a user from the guild."""
|
"""Unsubscribe a user from the guild."""
|
||||||
|
|
@ -101,6 +116,15 @@ class GuildDispatcher(DispatcherWithState):
|
||||||
await self.unsub(guild_id, user_id)
|
await self.unsub(guild_id, user_id)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# skip the given subscriber if event starts with PRESENCE_
|
||||||
|
# and the flags say they don't want it.
|
||||||
|
|
||||||
|
# note that this does not equate to any unsubscription
|
||||||
|
# of the channel.
|
||||||
|
if event.startswith('PRESENCE_') and \
|
||||||
|
not self.flags_get(guild_id, user_id, 'presence', True):
|
||||||
|
continue
|
||||||
|
|
||||||
# filter the ones that matter
|
# filter the ones that matter
|
||||||
states = list(filter(
|
states = list(filter(
|
||||||
lambda state: func(state.session_id), states
|
lambda state: func(state.session_id), states
|
||||||
|
|
@ -108,6 +132,7 @@ class GuildDispatcher(DispatcherWithState):
|
||||||
|
|
||||||
cur_sess = await self._dispatch_states(
|
cur_sess = await self._dispatch_states(
|
||||||
states, event, data)
|
states, event, data)
|
||||||
|
|
||||||
sessions.extend(cur_sess)
|
sessions.extend(cur_sess)
|
||||||
dispatched += len(cur_sess)
|
dispatched += len(cur_sess)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue