Merge branch 'impl/gateway-intents' into 'master'

Add Gateway Intents support

Closes #74

See merge request litecord/litecord!74
This commit is contained in:
luna 2021-07-14 19:05:00 +00:00
commit d3806353b6
8 changed files with 131 additions and 89 deletions

View File

@ -19,7 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import inspect import inspect
from typing import List, Any from typing import List, Any
from enum import Enum, IntEnum from enum import Enum, IntEnum, IntFlag
class EasyEnum(Enum): class EasyEnum(Enum):
@ -66,7 +66,7 @@ class Flags:
for attr, val in cls._attrs: for attr, val in cls._attrs:
has_attr = (value & val) == val has_attr = (value & val) == val
# set attributes dynamically # set attributes dynamically
setattr(res, f"is_{attr}", has_attr) setattr(res, f"is_{attr.lower()}", has_attr)
return res return res
@ -247,3 +247,76 @@ class Feature(EasyEnum):
# unknown # unknown
commerce = "COMMERCE" commerce = "COMMERCE"
news = "NEWS" news = "NEWS"
class Intents(IntFlag):
GUILDS = 1 << 0
GUILD_MEMBERS = 1 << 1
GUILD_BANS = 1 << 2
GUILD_EMOJIS = 1 << 3
GUILD_INTEGRATIONS = 1 << 4
GUILD_WEBHOOKS = 1 << 5
GUILD_INVITES = 1 << 6
GUILD_VOICE_STATES = 1 << 7
GUILD_PRESENCES = 1 << 8
GUILD_MESSAGES = 1 << 9
GUILD_MESSAGE_REACTIONS = 1 << 10
GUILD_MESSAGE_TYPING = 1 << 11
DIRECT_MESSAGES = 1 << 12
DIRECT_MESSAGE_REACTIONS = 1 << 13
DIRECT_MESSAGE_TYPING = 1 << 14
@classmethod
def default(cls):
return cls(-1)
EVENTS_TO_INTENTS = {
"GUILD_CREATE": Intents.GUILDS,
"GUILD_UPDATE": Intents.GUILDS,
"GUILD_DELETE": Intents.GUILDS,
"GUILD_ROLE_CREATE": Intents.GUILDS,
"GUILD_ROLE_UPDATE": Intents.GUILDS,
"GUILD_ROLE_DELETE": Intents.GUILDS,
"CHANNEL_CREATE": Intents.GUILDS,
"CHANNEL_UPDATE": Intents.GUILDS,
"CHANNEL_DELETE": Intents.GUILDS,
"CHANNEL_PINS_UPDATE": Intents.GUILDS,
# --- threads not supported --
"THREAD_CREATE": Intents.GUILDS,
"THREAD_UPDATE": Intents.GUILDS,
"THREAD_DELETE": Intents.GUILDS,
"THREAD_LIST_SYNC": Intents.GUILDS,
"THREAD_MEMBER_UPDATE": Intents.GUILDS,
"THREAD_MEMBERS_UPDATE": Intents.GUILDS,
# --- stages not supported --
"STAGE_INSTANCE_CREATE": Intents.GUILDS,
"STAGE_INSTANCE_UPDATE": Intents.GUILDS,
"STAGE_INSTANCE_DELETE": Intents.GUILDS,
"GUILD_MEMBER_ADD": Intents.GUILD_MEMBERS,
"GUILD_MEMBER_UPDATE": Intents.GUILD_MEMBERS,
"GUILD_MEMBER_REMOVE": Intents.GUILD_MEMBERS,
# --- threads not supported --
"THREAD_MEMBERS_UPDATE ": Intents.GUILD_MEMBERS,
"GUILD_BAN_ADD": Intents.GUILD_BANS,
"GUILD_BAN_REMOVE": Intents.GUILD_BANS,
"GUILD_EMOJIS_UPDATE": Intents.GUILD_EMOJIS,
"GUILD_INTEGRATIONS_UPDATE": Intents.GUILD_INTEGRATIONS,
"INTEGRATION_CREATE": Intents.GUILD_INTEGRATIONS,
"INTEGRATION_UPDATE": Intents.GUILD_INTEGRATIONS,
"INTEGRATION_DELETE": Intents.GUILD_INTEGRATIONS,
"WEBHOOKS_UPDATE": Intents.GUILD_WEBHOOKS,
"INVITE_CREATE": Intents.GUILD_INVITES,
"INVITE_DELETE": Intents.GUILD_INVITES,
"VOICE_STATE_UPDATE": Intents.GUILD_VOICE_STATES,
"PRESENCE_UPDATE": Intents.GUILD_PRESENCES,
"MESSAGE_CREATE": Intents.GUILD_MESSAGES,
"MESSAGE_UPDATE": Intents.GUILD_MESSAGES,
"MESSAGE_DELETE": Intents.GUILD_MESSAGES,
"MESSAGE_DELETE_BULK": Intents.GUILD_MESSAGES,
"MESSAGE_REACTION_ADD": Intents.GUILD_MESSAGE_REACTIONS,
"MESSAGE_REACTION_REMOVE": Intents.GUILD_MESSAGE_REACTIONS,
"MESSAGE_REACTION_REMOVE_ALL": Intents.GUILD_MESSAGE_REACTIONS,
"MESSAGE_REACTION_REMOVE_EMOJI": Intents.GUILD_MESSAGE_REACTIONS,
"TYPING_START": Intents.GUILD_MESSAGE_TYPING,
}

View File

@ -21,9 +21,9 @@ from typing import Dict
from logbook import Logger from logbook import Logger
from litecord.gateway.errors import DecodeError from litecord.gateway.errors import DecodeError
from litecord.schemas import LitecordValidator from litecord.schemas import LitecordValidator
from litecord.enums import Intents
log = Logger(__name__) log = Logger(__name__)
@ -64,6 +64,7 @@ IDENTIFY_SCHEMA = {
"large_threshold": {"type": "number", "required": False}, "large_threshold": {"type": "number", "required": False},
"shard": {"type": "list", "required": False}, "shard": {"type": "list", "required": False},
"presence": {"type": "dict", "required": False}, "presence": {"type": "dict", "required": False},
"intents": {"coerce": Intents, "required": False},
}, },
} }
}, },

View File

@ -22,6 +22,7 @@ import os
from typing import Optional from typing import Optional
from litecord.presence import BasePresence from litecord.presence import BasePresence
from litecord.enums import Intents
def gen_session_id() -> str: def gen_session_id() -> str:
@ -93,6 +94,7 @@ class GatewayState:
self.compress: bool = kwargs.get("compress") or False self.compress: bool = kwargs.get("compress") or False
self.large: int = kwargs.get("large") or 50 self.large: int = kwargs.get("large") or 50
self.intents: Intents = kwargs["intents"]
def __bool__(self): def __bool__(self):
"""Return if the given state is a valid state to be used.""" """Return if the given state is a valid state to be used."""

View File

@ -28,6 +28,7 @@ from logbook import Logger
from litecord.gateway.state import GatewayState from litecord.gateway.state import GatewayState
from litecord.gateway.opcodes import OP from litecord.gateway.opcodes import OP
from litecord.enums import Intents
log = Logger(__name__) log = Logger(__name__)
@ -174,6 +175,7 @@ class StateManager:
"game": None, "game": None,
"since": 0, "since": 0,
}, },
intents=Intents.default(),
) )
states.append(dummy_state) states.append(dummy_state)

View File

@ -31,7 +31,7 @@ from logbook import Logger
from quart import current_app as app from quart import current_app as app
from litecord.auth import raw_token_check from litecord.auth import raw_token_check
from litecord.enums import RelationshipType, ChannelType, ActivityType from litecord.enums import RelationshipType, ChannelType, ActivityType, Intents
from litecord.utils import ( from litecord.utils import (
task_wrapper, task_wrapper,
yield_chunks, yield_chunks,
@ -56,8 +56,6 @@ from litecord.gateway.errors import (
) )
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
from litecord.gateway.utils import WebsocketFileHandler from litecord.gateway.utils import WebsocketFileHandler
from litecord.pubsub.guild import GuildFlags
from litecord.pubsub.channel import ChannelFlags
from litecord.gateway.schemas import ( from litecord.gateway.schemas import (
validate, validate,
IDENTIFY_SCHEMA, IDENTIFY_SCHEMA,
@ -101,6 +99,20 @@ def _complete_users_list(user_id: str, base_ready, user_ready) -> dict:
return ready return ready
def calculate_intents(data) -> Intents:
intents_int = data.get("intents")
guild_subscriptions = data.get("guild_subscriptions")
if guild_subscriptions is False and intents_int is None:
intents_int = Intents(0)
intents_int |= Intents.GUILD_MESSAGE_TYPING
intents_int |= Intents.DIRECT_MESSAGE_TYPING
intents_int = ~intents_int
elif intents_int is None:
intents_int = Intents.default()
return Intents(intents_int)
class GatewayWebsocket: class GatewayWebsocket:
"""Main gateway websocket logic.""" """Main gateway websocket logic."""
@ -460,7 +472,7 @@ class GatewayWebsocket:
return list(filtered) return list(filtered)
async def subscribe_all(self, guild_subscriptions: bool): async def subscribe_all(self):
"""Subscribe to all guilds, DM channels, and friends. """Subscribe to all guilds, DM channels, and friends.
Note: subscribing to channels is already handled Note: subscribing to channels is already handled
@ -494,11 +506,7 @@ class GatewayWebsocket:
channel_ids: List[int] = [] channel_ids: List[int] = []
for guild_id in guild_ids: for guild_id in guild_ids:
await app.dispatcher.guild.sub_with_flags( await app.dispatcher.guild.sub(guild_id, session_id)
guild_id,
session_id,
GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions),
)
# instead of calculating which channels to subscribe to # instead of calculating which channels to subscribe to
# inside guild dispatcher, we calculate them in here, so that # inside guild dispatcher, we calculate them in here, so that
@ -515,9 +523,7 @@ class GatewayWebsocket:
log.info("subscribing to {} guild channels", len(channel_ids)) log.info("subscribing to {} guild channels", len(channel_ids))
for channel_id in channel_ids: for channel_id in channel_ids:
await app.dispatcher.channel.sub_with_flags( await app.dispatcher.channel.sub(channel_id, session_id)
channel_id, session_id, ChannelFlags(typing=guild_subscriptions)
)
for dm_id in dm_ids: for dm_id in dm_ids:
await app.dispatcher.channel.sub(dm_id, session_id) await app.dispatcher.channel.sub(dm_id, session_id)
@ -668,6 +674,8 @@ class GatewayWebsocket:
shard = data.get("shard", [0, 1]) shard = data.get("shard", [0, 1])
presence = data.get("presence") or {} presence = data.get("presence") or {}
intents = calculate_intents(data)
try: try:
user_id = await raw_token_check(token, self.app.db) user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden): except (Unauthorized, Forbidden):
@ -693,6 +701,7 @@ class GatewayWebsocket:
large=large, large=large,
current_shard=shard[0], current_shard=shard[0],
shard_count=shard[1], shard_count=shard[1],
intents=intents,
) )
self.state.ws = self self.state.ws = self
@ -703,7 +712,7 @@ class GatewayWebsocket:
settings = await self.user_storage.get_user_settings(user_id) settings = await self.user_storage.get_user_settings(user_id)
await self.update_presence(presence, settings=settings) await self.update_presence(presence, settings=settings)
await self.subscribe_all(data.get("guild_subscriptions", True)) await self.subscribe_all()
await self.dispatch_ready(settings=settings) await self.dispatch_ready(settings=settings)
async def handle_3(self, payload: Dict[str, Any]): async def handle_3(self, payload: Dict[str, Any]):

View File

@ -18,14 +18,13 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List from typing import List
from dataclasses import dataclass
from quart import current_app as app from quart import current_app as app
from logbook import Logger from logbook import Logger
from litecord.enums import ChannelType from litecord.enums import ChannelType, EVENTS_TO_INTENTS
from litecord.utils import index_by_func from litecord.utils import index_by_func
from .dispatcher import DispatcherWithFlags, GatewayEvent from .dispatcher import DispatcherWithState, GatewayEvent
log = Logger(__name__) log = Logger(__name__)
@ -44,14 +43,7 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
return data return data
@dataclass class ChannelDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
class ChannelFlags:
typing: bool
class ChannelDispatcher(
DispatcherWithFlags[int, str, GatewayEvent, List[str], ChannelFlags]
):
"""Main channel Pub/Sub logic. Handles both Guild, DM, and Group DM channels.""" """Main channel Pub/Sub logic. Handles both Guild, DM, and Group DM channels."""
async def dispatch(self, channel_id: int, event: GatewayEvent) -> List[str]: async def dispatch(self, channel_id: int, event: GatewayEvent) -> List[str]:
@ -69,14 +61,11 @@ class ChannelDispatcher(
await self.unsub(channel_id, session_id) await self.unsub(channel_id, session_id)
continue continue
try: wanted_intent = EVENTS_TO_INTENTS.get(event_type)
flags = self.get_flags(channel_id, session_id) if wanted_intent is not None:
except KeyError: state_has_intent = (state.intents & wanted_intent) == wanted_intent
log.warning("no flags for {!r}, ignoring", session_id) if not state_has_intent:
flags = ChannelFlags(typing=True) continue
if event_type.lower().startswith("typing_") and not flags.typing:
continue
correct_event = event correct_event = event
# for cases where we are talking about group dms, we create an edited # for cases where we are talking about group dms, we create an edited

View File

@ -45,7 +45,7 @@ F_Map = Mapping[V, F]
GatewayEvent = Tuple[str, Any] GatewayEvent = Tuple[str, Any]
__all__ = ["Dispatcher", "DispatcherWithState", "DispatcherWithFlags", "GatewayEvent"] __all__ = ["Dispatcher", "DispatcherWithState", "GatewayEvent"]
class Dispatcher(Generic[K, V, EventType, DispatchType]): class Dispatcher(Generic[K, V, EventType, DispatchType]):
@ -123,39 +123,3 @@ class DispatcherWithState(Dispatcher[K, V, EventType, DispatchType]):
self.state.pop(key) self.state.pop(key)
except KeyError: except KeyError:
pass pass
class DispatcherWithFlags(
DispatcherWithState,
Generic[K, V, EventType, DispatchType, F],
):
"""Pub/Sub backend with both a state and a flags store."""
def __init__(self):
super().__init__()
self.flags: Mapping[K, Dict[V, F]] = defaultdict(dict)
def set_flags(self, key: K, identifier: V, flags: F):
"""Set flags for the given identifier."""
self.flags[key][identifier] = flags
def remove_flags(self, key: K, identifier: V):
"""Set flags for the given identifier."""
try:
self.flags[key].pop(identifier)
except KeyError:
pass
def get_flags(self, key: K, identifier: V):
"""Get a single field from the flags store."""
return self.flags[key][identifier]
async def sub_with_flags(self, key: K, identifier: V, flags: F):
"""Subscribe a user to the guild."""
await super().sub(key, identifier)
self.set_flags(key, identifier, flags)
async def unsub(self, key: K, identifier: V):
"""Unsubscribe a user from the guild."""
await super().unsub(key, identifier)
self.remove_flags(key, identifier)

View File

@ -18,26 +18,34 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """
from typing import List from typing import List
from dataclasses import dataclass
from quart import current_app as app from quart import current_app as app
from logbook import Logger from logbook import Logger
from .dispatcher import DispatcherWithFlags, GatewayEvent from .dispatcher import DispatcherWithState, GatewayEvent
from .channel import ChannelFlags
from litecord.gateway.state import GatewayState from litecord.gateway.state import GatewayState
from litecord.enums import EVENTS_TO_INTENTS
log = Logger(__name__) log = Logger(__name__)
@dataclass def can_dispatch(event_type, event_data, state) -> bool:
class GuildFlags(ChannelFlags): # If we're sending to the same user for this kind of event,
presence: bool # bypass event logic (always send)
if event_type == "GUILD_MEMBER_UPDATE":
user_id = int(event_data["user"])
return user_id == state.user_id
# TODO Guild Create and Req Guild Members have specific
# logic regarding the presence intent.
wanted_intent = EVENTS_TO_INTENTS.get(event_type)
if wanted_intent is not None:
state_has_intent = (state.intents & wanted_intent) == wanted_intent
return state_has_intent
class GuildDispatcher( class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
DispatcherWithFlags[int, str, GatewayEvent, List[str], GuildFlags]
):
"""Guild backend for Pub/Sub.""" """Guild backend for Pub/Sub."""
async def sub_user(self, guild_id: int, user_id: int) -> List[GatewayState]: async def sub_user(self, guild_id: int, user_id: int) -> List[GatewayState]:
@ -52,7 +60,7 @@ class GuildDispatcher(
): ):
session_ids = self.state[guild_id] session_ids = self.state[guild_id]
sessions: List[str] = [] sessions: List[str] = []
event_type, _ = event event_type, event_data = event
for session_id in set(session_ids): for session_id in set(session_ids):
if not filter_function(session_id): if not filter_function(session_id):
@ -68,13 +76,7 @@ class GuildDispatcher(
await self.unsub(guild_id, session_id) await self.unsub(guild_id, session_id)
continue continue
try: if not can_dispatch(event_type, event_data, state):
flags = self.get_flags(guild_id, session_id)
except KeyError:
log.warning("no flags for {!r}, ignoring", session_id)
flags = GuildFlags(presence=True, typing=True)
if event_type.lower().startswith("presence_") and not flags.presence:
continue continue
try: try: