Merge branch 'impl/gw-guild-subscriptions' into 'master'

Gateway guild subscriptions

See merge request litecord/litecord!40
This commit is contained in:
Luna 2019-07-20 21:29:46 +00:00
commit f0dde07418
5 changed files with 117 additions and 29 deletions

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List, Any from typing import List, Any, Dict
from logbook import Logger from logbook import Logger
@ -57,7 +57,7 @@ class EventDispatcher:
'lazy_guild': LazyGuildDispatcher(self), 'lazy_guild': LazyGuildDispatcher(self),
} }
async def action(self, backend_str: str, action: str, key, identifier): async def action(self, backend_str: str, action: str, key, identifier, *args):
"""Send an action regarding a key/identifier pair to a backend. """Send an action regarding a key/identifier pair to a backend.
Action is usually "sub" or "unsub". Action is usually "sub" or "unsub".
@ -69,13 +69,24 @@ class EventDispatcher:
key = backend.KEY_TYPE(key) key = backend.KEY_TYPE(key)
identifier = backend.VAL_TYPE(identifier) identifier = backend.VAL_TYPE(identifier)
return await method(key, identifier) return await method(key, identifier, *args)
async def subscribe(self, backend: str, key: Any, identifier: Any): async def subscribe(self, backend: str, key: Any, identifier: Any,
flags: Dict[str, Any] = None):
"""Subscribe a single element to the given backend.""" """Subscribe a single element to the given backend."""
flags = flags or {}
log.debug('SUB backend={} key={} <= id={}', log.debug('SUB backend={} key={} <= id={}',
backend, key, identifier, backend) backend, key, identifier, backend)
# this is a hacky solution for backwards compatibility between backends
# that implement flags and backends that don't.
# passing flags to backends that don't implement flags will
# cause errors as expected.
if flags:
return await self.action(backend, 'sub', key, identifier, flags)
return await self.action(backend, 'sub', key, identifier) return await self.action(backend, 'sub', key, identifier)
async def unsubscribe(self, backend: str, key: Any, identifier: Any): async def unsubscribe(self, backend: str, key: Any, identifier: Any):
@ -93,24 +104,34 @@ class EventDispatcher:
"""Alias to unsubscribe().""" """Alias to unsubscribe()."""
return await self.unsubscribe(backend, key, identifier) return await self.unsubscribe(backend, key, identifier)
async def sub_many(self, backend_str: str, identifier: Any, keys: list): async def sub_many(self, backend_str: str, identifier: Any,
keys: list, flags: Dict[str, Any] = None):
"""Subscribe to multiple channels (all in a single backend) """Subscribe to multiple channels (all in a single backend)
at a time. at a time.
Usually used when connecting to the gateway and the client Usually used when connecting to the gateway and the client
needs to subscribe to all their guids. needs to subscribe to all their guids.
""" """
flags = flags or {}
for key in keys: for key in keys:
await self.subscribe(backend_str, key, identifier) await self.subscribe(backend_str, key, identifier, flags)
async def mass_sub(self, identifier: Any, async def mass_sub(self, identifier: Any,
backends: List[tuple]): backends: List[tuple]):
"""Mass subscribe to many backends at once.""" """Mass subscribe to many backends at once."""
for backend_str, keys in backends: for bcall in backends:
log.debug('subscribing {} to {} keys in backend {}', backend_str, keys = bcall[0], bcall[1]
identifier, len(keys), backend_str)
await self.sub_many(backend_str, identifier, keys) if len(bcall) == 2:
flags = {}
elif len(bcall) == 3:
# we have flags
flags = bcall[2]
log.debug('subscribing {} to {} keys in backend {}, flags: {}',
identifier, len(keys), backend_str, flags)
await self.sub_many(backend_str, identifier, keys, flags)
async def dispatch(self, backend_str: str, key: Any, *args, **kwargs): async def dispatch(self, backend_str: str, key: Any, *args, **kwargs):
"""Dispatch an event to the backend. """Dispatch an event to the backend.

View File

@ -421,7 +421,7 @@ class GatewayWebsocket:
return list(filtered) return list(filtered)
async def subscribe_all(self): async def subscribe_all(self, guild_subscriptions: bool):
"""Subscribe to all guilds, DM channels, and friends. """Subscribe to all guilds, DM channels, and friends.
Note: subscribing to channels is already handled Note: subscribing to channels is already handled
@ -437,15 +437,23 @@ class GatewayWebsocket:
# fetch all group dms the user is a member of. # fetch all group dms the user is a member of.
gdm_ids = await self.user_storage.get_gdms_internal(user_id) gdm_ids = await self.user_storage.get_gdms_internal(user_id)
log.info('subscribing to {} guilds', len(guild_ids)) log.info('subscribing to {} guilds {} dms {} gdms',
log.info('subscribing to {} dms', len(dm_ids)) len(guild_ids), len(dm_ids), len(gdm_ids))
log.info('subscribing to {} group dms', len(gdm_ids))
await self.ext.dispatcher.mass_sub(user_id, [ # guild_subscriptions:
('guild', guild_ids), # enables dispatching of guild subscription events
# (presence and typing events)
# we enable processing of guild_subscriptions by adding flags
# when subscribing to the given backend. those are optional.
channels_to_sub = [
('guild', guild_ids,
{'presence': guild_subscriptions, 'typing': guild_subscriptions}),
('channel', dm_ids), ('channel', dm_ids),
('channel', gdm_ids) ('channel', gdm_ids),
]) ]
await self.ext.dispatcher.mass_sub(user_id, channels_to_sub)
if not self.state.bot: if not self.state.bot:
# subscribe to all friends # subscribe to all friends
@ -573,7 +581,7 @@ class GatewayWebsocket:
self.ext.state_manager.insert(self.state) self.ext.state_manager.insert(self.state)
await self.update_status(presence) await self.update_status(presence)
await self.subscribe_all() await self.subscribe_all(data.get('guild_subscriptions', True))
await self.dispatch_ready() await self.dispatch_ready()
async def handle_3(self, payload: Dict[str, Any]): async def handle_3(self, payload: Dict[str, Any]):

View File

@ -21,7 +21,7 @@ from typing import Any, List
from logbook import Logger from logbook import Logger
from .dispatcher import DispatcherWithState from .dispatcher import DispatcherWithFlags
from litecord.enums import ChannelType from litecord.enums import ChannelType
from litecord.utils import index_by_func from litecord.utils import index_by_func
@ -48,7 +48,7 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
return data return data
class ChannelDispatcher(DispatcherWithState): class ChannelDispatcher(DispatcherWithFlags):
"""Main channel Pub/Sub logic.""" """Main channel Pub/Sub logic."""
KEY_TYPE = int KEY_TYPE = int
VAL_TYPE = int VAL_TYPE = int
@ -84,6 +84,11 @@ class ChannelDispatcher(DispatcherWithState):
await self.unsub(channel_id, user_id) await self.unsub(channel_id, user_id)
continue continue
# skip typing events for users that don't want it
if event.startswith('TYPING_') and \
not self.flags_get(channel_id, user_id, 'typing', True):
continue
cur_sess = [] cur_sess = []
if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \ if event in ('CHANNEL_CREATE', 'CHANNEL_UPDATE') \

View File

@ -89,7 +89,7 @@ class Dispatcher:
try: try:
await state.ws.dispatch(event, data) await state.ws.dispatch(event, data)
res.append(state.session_id) res.append(state.session_id)
except: except Exception:
log.exception('error while dispatching') log.exception('error while dispatching')
return res return res
@ -128,3 +128,32 @@ class DispatcherWithState(Dispatcher):
async def dispatch(self, key, *args): async def dispatch(self, key, *args):
raise NotImplementedError raise NotImplementedError
class DispatcherWithFlags(DispatcherWithState):
"""Pub/Sub backend with both a state and a flags store."""
def __init__(self, main):
super().__init__(main)
#: keep flags for subscribers, so for example
# a subscriber could drop all presence events at the
# pubsub level. see gateway's guild_subscriptions field for more
self.flags = defaultdict(dict)
async def sub(self, key, identifier, flags=None):
"""Subscribe a user to the guild."""
await super().sub(key, identifier)
self.flags[key][identifier] = flags or {}
async def unsub(self, key, identifier):
"""Unsubscribe a user from the guild."""
await super().unsub(key, identifier)
self.flags[key].pop(identifier)
def flags_get(self, key, identifier, field: str, default):
"""Get a single field from the flags store."""
# yes, i know its simply an indirection from the main flags store,
# but i'd rather have this than change every call if i ever change
# the structure of the flags store.
return self.flags[key][identifier].get(field, default)

View File

@ -21,20 +21,21 @@ from typing import Any
from logbook import Logger from logbook import Logger
from .dispatcher import DispatcherWithState from .dispatcher import DispatcherWithFlags
from litecord.permissions import get_permissions from litecord.permissions import get_permissions
log = Logger(__name__) log = Logger(__name__)
class GuildDispatcher(DispatcherWithState): class GuildDispatcher(DispatcherWithFlags):
"""Guild backend for Pub/Sub""" """Guild backend for Pub/Sub"""
KEY_TYPE = int KEY_TYPE = int
VAL_TYPE = int VAL_TYPE = int
async def _chan_action(self, action: str, async def _chan_action(self, action: str,
guild_id: int, user_id: int): guild_id: int, user_id: int, flags=None):
"""Send an action to all channels of the guild.""" """Send an action to all channels of the guild."""
flags = flags or {}
chan_ids = await self.app.storage.get_channel_ids(guild_id) chan_ids = await self.app.storage.get_channel_ids(guild_id)
for chan_id in chan_ids: for chan_id in chan_ids:
@ -53,8 +54,22 @@ class GuildDispatcher(DispatcherWithState):
log.debug('sending raw action {!r} to chan={}', log.debug('sending raw action {!r} to chan={}',
action, chan_id) action, chan_id)
# for now, only sub() has support for flags.
# it is an idea to have flags support for other actions
args = []
if action == 'sub':
chanflags = dict(flags)
# channels don't need presence flags
try:
chanflags.pop('presence')
except KeyError:
pass
args.append(chanflags)
await self.main_dispatcher.action( await self.main_dispatcher.action(
'channel', action, chan_id, user_id 'channel', action, chan_id, user_id, *args
) )
async def _chan_call(self, meth: str, guild_id: int, *args): async def _chan_call(self, meth: str, guild_id: int, *args):
@ -70,10 +85,10 @@ class GuildDispatcher(DispatcherWithState):
meth, chan_id) meth, chan_id)
await method(chan_id, *args) await method(chan_id, *args)
async def sub(self, guild_id: int, user_id: int): async def sub(self, guild_id: int, user_id: int, flags=None):
"""Subscribe a user to the guild.""" """Subscribe a user to the guild."""
await super().sub(guild_id, user_id) await super().sub(guild_id, user_id, flags)
await self._chan_action('sub', guild_id, user_id) await self._chan_action('sub', guild_id, user_id, flags)
async def unsub(self, guild_id: int, user_id: int): async def unsub(self, guild_id: int, user_id: int):
"""Unsubscribe a user from the guild.""" """Unsubscribe a user from the guild."""
@ -101,6 +116,15 @@ class GuildDispatcher(DispatcherWithState):
await self.unsub(guild_id, user_id) await self.unsub(guild_id, user_id)
continue continue
# skip the given subscriber if event starts with PRESENCE_
# and the flags say they don't want it.
# note that this does not equate to any unsubscription
# of the channel.
if event.startswith('PRESENCE_') and \
not self.flags_get(guild_id, user_id, 'presence', True):
continue
# filter the ones that matter # filter the ones that matter
states = list(filter( states = list(filter(
lambda state: func(state.session_id), states lambda state: func(state.session_id), states
@ -108,6 +132,7 @@ class GuildDispatcher(DispatcherWithState):
cur_sess = await self._dispatch_states( cur_sess = await self._dispatch_states(
states, event, data) states, event, data)
sessions.extend(cur_sess) sessions.extend(cur_sess)
dispatched += len(cur_sess) dispatched += len(cur_sess)