mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'impl/gw-guild-subscriptions' into 'master'
Gateway guild subscriptions See merge request litecord/litecord!40
This commit is contained in:
commit
f0dde07418
|
|
@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
"""
|
||||
|
||||
from typing import List, Any
|
||||
from typing import List, Any, Dict
|
||||
|
||||
from logbook import Logger
|
||||
|
||||
|
|
@ -57,7 +57,7 @@ class EventDispatcher:
|
|||
'lazy_guild': LazyGuildDispatcher(self),
|
||||
}
|
||||
|
||||
async def action(self, backend_str: str, action: str, key, identifier):
|
||||
async def action(self, backend_str: str, action: str, key, identifier, *args):
|
||||
"""Send an action regarding a key/identifier pair to a backend.
|
||||
|
||||
Action is usually "sub" or "unsub".
|
||||
|
|
@ -69,13 +69,24 @@ class EventDispatcher:
|
|||
key = backend.KEY_TYPE(key)
|
||||
identifier = backend.VAL_TYPE(identifier)
|
||||
|
||||
return await method(key, identifier)
|
||||
return await method(key, identifier, *args)
|
||||
|
||||
async def subscribe(self, backend: str, key: Any, identifier: Any):
|
||||
async def subscribe(self, backend: str, key: Any, identifier: Any,
|
||||
flags: Dict[str, Any] = None):
|
||||
"""Subscribe a single element to the given backend."""
|
||||
flags = flags or {}
|
||||
|
||||
log.debug('SUB backend={} key={} <= id={}',
|
||||
backend, key, identifier, backend)
|
||||
|
||||
# this is a hacky solution for backwards compatibility between backends
|
||||
# that implement flags and backends that don't.
|
||||
|
||||
# passing flags to backends that don't implement flags will
|
||||
# cause errors as expected.
|
||||
if flags:
|
||||
return await self.action(backend, 'sub', key, identifier, flags)
|
||||
|
||||
return await self.action(backend, 'sub', key, identifier)
|
||||
|
||||
async def unsubscribe(self, backend: str, key: Any, identifier: Any):
|
||||
|
|
@ -93,24 +104,34 @@ class EventDispatcher:
|
|||
"""Alias to unsubscribe()."""
|
||||
return await self.unsubscribe(backend, key, identifier)
|
||||
|
||||
async def sub_many(self, backend_str: str, identifier: Any, keys: list):
|
||||
async def sub_many(self, backend_str: str, identifier: Any,
|
||||
keys: list, flags: Dict[str, Any] = None):
|
||||
"""Subscribe to multiple channels (all in a single backend)
|
||||
at a time.
|
||||
|
||||
Usually used when connecting to the gateway and the client
|
||||
needs to subscribe to all their guids.
|
||||
"""
|
||||
flags = flags or {}
|
||||
for key in keys:
|
||||
await self.subscribe(backend_str, key, identifier)
|
||||
await self.subscribe(backend_str, key, identifier, flags)
|
||||
|
||||
async def mass_sub(self, identifier: Any,
|
||||
backends: List[tuple]):
|
||||
"""Mass subscribe to many backends at once."""
|
||||
for backend_str, keys in backends:
|
||||
log.debug('subscribing {} to {} keys in backend {}',
|
||||
identifier, len(keys), backend_str)
|
||||
for bcall in backends:
|
||||
backend_str, keys = bcall[0], bcall[1]
|
||||
|
||||
await self.sub_many(backend_str, identifier, keys)
|
||||
if len(bcall) == 2:
|
||||
flags = {}
|
||||
elif len(bcall) == 3:
|
||||
# we have flags
|
||||
flags = bcall[2]
|
||||
|
||||
log.debug('subscribing {} to {} keys in backend {}, flags: {}',
|
||||
identifier, len(keys), backend_str, flags)
|
||||
|
||||
await self.sub_many(backend_str, identifier, keys, flags)
|
||||
|
||||
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
|
||||
"""Dispatch an event to the backend.
|
||||
|
|
|
|||
|
|
@ -421,7 +421,7 @@ class GatewayWebsocket:
|
|||
|
||||
return list(filtered)
|
||||
|
||||
async def subscribe_all(self):
|
||||
async def subscribe_all(self, guild_subscriptions: bool):
|
||||
"""Subscribe to all guilds, DM channels, and friends.
|
||||
|
||||
Note: subscribing to channels is already handled
|
||||
|
|
@ -437,15 +437,23 @@ class GatewayWebsocket:
|
|||
# fetch all group dms the user is a member of.
|
||||
gdm_ids = await self.user_storage.get_gdms_internal(user_id)
|
||||
|
||||
log.info('subscribing to {} guilds', len(guild_ids))
|
||||
log.info('subscribing to {} dms', len(dm_ids))
|
||||
log.info('subscribing to {} group dms', len(gdm_ids))
|
||||
log.info('subscribing to {} guilds {} dms {} gdms',
|
||||
len(guild_ids), len(dm_ids), len(gdm_ids))
|
||||
|
||||
await self.ext.dispatcher.mass_sub(user_id, [
|
||||
('guild', guild_ids),
|
||||
# guild_subscriptions:
|
||||
# enables dispatching of guild subscription events
|
||||
# (presence and typing events)
|
||||
|
||||
# we enable processing of guild_subscriptions by adding flags
|
||||
# when subscribing to the given backend. those are optional.
|
||||
channels_to_sub = [
|
||||
('guild', guild_ids,
|
||||
{'presence': guild_subscriptions, 'typing': guild_subscriptions}),
|
||||
('channel', dm_ids),
|
||||
('channel', gdm_ids)
|
||||
])
|
||||
('channel', gdm_ids),
|
||||
]
|
||||
|
||||
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
|
||||
|
||||
if not self.state.bot:
|
||||
# subscribe to all friends
|
||||
|
|
@ -573,7 +581,7 @@ class GatewayWebsocket:
|
|||
self.ext.state_manager.insert(self.state)
|
||||
|
||||
await self.update_status(presence)
|
||||
await self.subscribe_all()
|
||||
await self.subscribe_all(data.get('guild_subscriptions', True))
|
||||
await self.dispatch_ready()
|
||||
|
||||
async def handle_3(self, payload: Dict[str, Any]):
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from typing import Any, List
|
|||
|
||||
from logbook import Logger
|
||||
|
||||
from .dispatcher import DispatcherWithState
|
||||
from .dispatcher import DispatcherWithFlags
|
||||
from litecord.enums import ChannelType
|
||||
from litecord.utils import index_by_func
|
||||
|
||||
|
|
@ -48,7 +48,7 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
|
|||
return data
|
||||
|
||||
|
||||
class ChannelDispatcher(DispatcherWithState):
|
||||
class ChannelDispatcher(DispatcherWithFlags):
|
||||
"""Main channel Pub/Sub logic."""
|
||||
KEY_TYPE = int
|
||||
VAL_TYPE = int
|
||||
|
|
@ -84,6 +84,11 @@ class ChannelDispatcher(DispatcherWithState):
|
|||
await self.unsub(channel_id, user_id)
|
||||
continue
|
||||
|
||||
# skip typing events for users that don't want it
|
||||
if event.startswith('TYPING_') and \
|
||||
not self.flags_get(channel_id, user_id, 'typing', True):
|
||||
continue
|
||||
|
||||
cur_sess = []
|
||||
|
||||
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ class Dispatcher:
|
|||
try:
|
||||
await state.ws.dispatch(event, data)
|
||||
res.append(state.session_id)
|
||||
except:
|
||||
except Exception:
|
||||
log.exception('error while dispatching')
|
||||
|
||||
return res
|
||||
|
|
@ -128,3 +128,32 @@ class DispatcherWithState(Dispatcher):
|
|||
|
||||
async def dispatch(self, key, *args):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DispatcherWithFlags(DispatcherWithState):
|
||||
"""Pub/Sub backend with both a state and a flags store."""
|
||||
|
||||
def __init__(self, main):
|
||||
super().__init__(main)
|
||||
|
||||
#: keep flags for subscribers, so for example
|
||||
# a subscriber could drop all presence events at the
|
||||
# pubsub level. see gateway's guild_subscriptions field for more
|
||||
self.flags = defaultdict(dict)
|
||||
|
||||
async def sub(self, key, identifier, flags=None):
|
||||
"""Subscribe a user to the guild."""
|
||||
await super().sub(key, identifier)
|
||||
self.flags[key][identifier] = flags or {}
|
||||
|
||||
async def unsub(self, key, identifier):
|
||||
"""Unsubscribe a user from the guild."""
|
||||
await super().unsub(key, identifier)
|
||||
self.flags[key].pop(identifier)
|
||||
|
||||
def flags_get(self, key, identifier, field: str, default):
|
||||
"""Get a single field from the flags store."""
|
||||
# yes, i know its simply an indirection from the main flags store,
|
||||
# but i'd rather have this than change every call if i ever change
|
||||
# the structure of the flags store.
|
||||
return self.flags[key][identifier].get(field, default)
|
||||
|
|
|
|||
|
|
@ -21,20 +21,21 @@ from typing import Any
|
|||
|
||||
from logbook import Logger
|
||||
|
||||
from .dispatcher import DispatcherWithState
|
||||
from .dispatcher import DispatcherWithFlags
|
||||
from litecord.permissions import get_permissions
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
class GuildDispatcher(DispatcherWithState):
|
||||
class GuildDispatcher(DispatcherWithFlags):
|
||||
"""Guild backend for Pub/Sub"""
|
||||
KEY_TYPE = int
|
||||
VAL_TYPE = int
|
||||
|
||||
async def _chan_action(self, action: str,
|
||||
guild_id: int, user_id: int):
|
||||
guild_id: int, user_id: int, flags=None):
|
||||
"""Send an action to all channels of the guild."""
|
||||
flags = flags or {}
|
||||
chan_ids = await self.app.storage.get_channel_ids(guild_id)
|
||||
|
||||
for chan_id in chan_ids:
|
||||
|
|
@ -53,8 +54,22 @@ class GuildDispatcher(DispatcherWithState):
|
|||
log.debug('sending raw action {!r} to chan={}',
|
||||
action, chan_id)
|
||||
|
||||
# for now, only sub() has support for flags.
|
||||
# it is an idea to have flags support for other actions
|
||||
args = []
|
||||
if action == 'sub':
|
||||
chanflags = dict(flags)
|
||||
|
||||
# channels don't need presence flags
|
||||
try:
|
||||
chanflags.pop('presence')
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
args.append(chanflags)
|
||||
|
||||
await self.main_dispatcher.action(
|
||||
'channel', action, chan_id, user_id
|
||||
'channel', action, chan_id, user_id, *args
|
||||
)
|
||||
|
||||
async def _chan_call(self, meth: str, guild_id: int, *args):
|
||||
|
|
@ -70,10 +85,10 @@ class GuildDispatcher(DispatcherWithState):
|
|||
meth, chan_id)
|
||||
await method(chan_id, *args)
|
||||
|
||||
async def sub(self, guild_id: int, user_id: int):
|
||||
async def sub(self, guild_id: int, user_id: int, flags=None):
|
||||
"""Subscribe a user to the guild."""
|
||||
await super().sub(guild_id, user_id)
|
||||
await self._chan_action('sub', guild_id, user_id)
|
||||
await super().sub(guild_id, user_id, flags)
|
||||
await self._chan_action('sub', guild_id, user_id, flags)
|
||||
|
||||
async def unsub(self, guild_id: int, user_id: int):
|
||||
"""Unsubscribe a user from the guild."""
|
||||
|
|
@ -101,6 +116,15 @@ class GuildDispatcher(DispatcherWithState):
|
|||
await self.unsub(guild_id, user_id)
|
||||
continue
|
||||
|
||||
# skip the given subscriber if event starts with PRESENCE_
|
||||
# and the flags say they don't want it.
|
||||
|
||||
# note that this does not equate to any unsubscription
|
||||
# of the channel.
|
||||
if event.startswith('PRESENCE_') and \
|
||||
not self.flags_get(guild_id, user_id, 'presence', True):
|
||||
continue
|
||||
|
||||
# filter the ones that matter
|
||||
states = list(filter(
|
||||
lambda state: func(state.session_id), states
|
||||
|
|
@ -108,6 +132,7 @@ class GuildDispatcher(DispatcherWithState):
|
|||
|
||||
cur_sess = await self._dispatch_states(
|
||||
states, event, data)
|
||||
|
||||
sessions.extend(cur_sess)
|
||||
dispatched += len(cur_sess)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue