Merge branch 'impl/gateway-intents' into 'master'

Add Gateway Intents support

Closes #74

See merge request litecord/litecord!74
This commit is contained in:
luna 2021-07-14 19:05:00 +00:00
commit d3806353b6
8 changed files with 131 additions and 89 deletions

View File

@ -19,7 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import inspect
from typing import List, Any
from enum import Enum, IntEnum
from enum import Enum, IntEnum, IntFlag
class EasyEnum(Enum):
@ -66,7 +66,7 @@ class Flags:
for attr, val in cls._attrs:
has_attr = (value & val) == val
# set attributes dynamically
setattr(res, f"is_{attr}", has_attr)
setattr(res, f"is_{attr.lower()}", has_attr)
return res
@ -247,3 +247,76 @@ class Feature(EasyEnum):
# unknown
commerce = "COMMERCE"
news = "NEWS"
class Intents(IntFlag):
GUILDS = 1 << 0
GUILD_MEMBERS = 1 << 1
GUILD_BANS = 1 << 2
GUILD_EMOJIS = 1 << 3
GUILD_INTEGRATIONS = 1 << 4
GUILD_WEBHOOKS = 1 << 5
GUILD_INVITES = 1 << 6
GUILD_VOICE_STATES = 1 << 7
GUILD_PRESENCES = 1 << 8
GUILD_MESSAGES = 1 << 9
GUILD_MESSAGE_REACTIONS = 1 << 10
GUILD_MESSAGE_TYPING = 1 << 11
DIRECT_MESSAGES = 1 << 12
DIRECT_MESSAGE_REACTIONS = 1 << 13
DIRECT_MESSAGE_TYPING = 1 << 14
@classmethod
def default(cls):
return cls(-1)
EVENTS_TO_INTENTS = {
"GUILD_CREATE": Intents.GUILDS,
"GUILD_UPDATE": Intents.GUILDS,
"GUILD_DELETE": Intents.GUILDS,
"GUILD_ROLE_CREATE": Intents.GUILDS,
"GUILD_ROLE_UPDATE": Intents.GUILDS,
"GUILD_ROLE_DELETE": Intents.GUILDS,
"CHANNEL_CREATE": Intents.GUILDS,
"CHANNEL_UPDATE": Intents.GUILDS,
"CHANNEL_DELETE": Intents.GUILDS,
"CHANNEL_PINS_UPDATE": Intents.GUILDS,
# --- threads not supported --
"THREAD_CREATE": Intents.GUILDS,
"THREAD_UPDATE": Intents.GUILDS,
"THREAD_DELETE": Intents.GUILDS,
"THREAD_LIST_SYNC": Intents.GUILDS,
"THREAD_MEMBER_UPDATE": Intents.GUILDS,
"THREAD_MEMBERS_UPDATE": Intents.GUILDS,
# --- stages not supported --
"STAGE_INSTANCE_CREATE": Intents.GUILDS,
"STAGE_INSTANCE_UPDATE": Intents.GUILDS,
"STAGE_INSTANCE_DELETE": Intents.GUILDS,
"GUILD_MEMBER_ADD": Intents.GUILD_MEMBERS,
"GUILD_MEMBER_UPDATE": Intents.GUILD_MEMBERS,
"GUILD_MEMBER_REMOVE": Intents.GUILD_MEMBERS,
# --- threads not supported --
"THREAD_MEMBERS_UPDATE ": Intents.GUILD_MEMBERS,
"GUILD_BAN_ADD": Intents.GUILD_BANS,
"GUILD_BAN_REMOVE": Intents.GUILD_BANS,
"GUILD_EMOJIS_UPDATE": Intents.GUILD_EMOJIS,
"GUILD_INTEGRATIONS_UPDATE": Intents.GUILD_INTEGRATIONS,
"INTEGRATION_CREATE": Intents.GUILD_INTEGRATIONS,
"INTEGRATION_UPDATE": Intents.GUILD_INTEGRATIONS,
"INTEGRATION_DELETE": Intents.GUILD_INTEGRATIONS,
"WEBHOOKS_UPDATE": Intents.GUILD_WEBHOOKS,
"INVITE_CREATE": Intents.GUILD_INVITES,
"INVITE_DELETE": Intents.GUILD_INVITES,
"VOICE_STATE_UPDATE": Intents.GUILD_VOICE_STATES,
"PRESENCE_UPDATE": Intents.GUILD_PRESENCES,
"MESSAGE_CREATE": Intents.GUILD_MESSAGES,
"MESSAGE_UPDATE": Intents.GUILD_MESSAGES,
"MESSAGE_DELETE": Intents.GUILD_MESSAGES,
"MESSAGE_DELETE_BULK": Intents.GUILD_MESSAGES,
"MESSAGE_REACTION_ADD": Intents.GUILD_MESSAGE_REACTIONS,
"MESSAGE_REACTION_REMOVE": Intents.GUILD_MESSAGE_REACTIONS,
"MESSAGE_REACTION_REMOVE_ALL": Intents.GUILD_MESSAGE_REACTIONS,
"MESSAGE_REACTION_REMOVE_EMOJI": Intents.GUILD_MESSAGE_REACTIONS,
"TYPING_START": Intents.GUILD_MESSAGE_TYPING,
}

View File

@ -21,9 +21,9 @@ from typing import Dict
from logbook import Logger
from litecord.gateway.errors import DecodeError
from litecord.schemas import LitecordValidator
from litecord.enums import Intents
log = Logger(__name__)
@ -64,6 +64,7 @@ IDENTIFY_SCHEMA = {
"large_threshold": {"type": "number", "required": False},
"shard": {"type": "list", "required": False},
"presence": {"type": "dict", "required": False},
"intents": {"coerce": Intents, "required": False},
},
}
},

View File

@ -22,6 +22,7 @@ import os
from typing import Optional
from litecord.presence import BasePresence
from litecord.enums import Intents
def gen_session_id() -> str:
@ -93,6 +94,7 @@ class GatewayState:
self.compress: bool = kwargs.get("compress") or False
self.large: int = kwargs.get("large") or 50
self.intents: Intents = kwargs["intents"]
def __bool__(self):
"""Return if the given state is a valid state to be used."""

View File

@ -28,6 +28,7 @@ from logbook import Logger
from litecord.gateway.state import GatewayState
from litecord.gateway.opcodes import OP
from litecord.enums import Intents
log = Logger(__name__)
@ -174,6 +175,7 @@ class StateManager:
"game": None,
"since": 0,
},
intents=Intents.default(),
)
states.append(dummy_state)

View File

@ -31,7 +31,7 @@ from logbook import Logger
from quart import current_app as app
from litecord.auth import raw_token_check
from litecord.enums import RelationshipType, ChannelType, ActivityType
from litecord.enums import RelationshipType, ChannelType, ActivityType, Intents
from litecord.utils import (
task_wrapper,
yield_chunks,
@ -56,8 +56,6 @@ from litecord.gateway.errors import (
)
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
from litecord.gateway.utils import WebsocketFileHandler
from litecord.pubsub.guild import GuildFlags
from litecord.pubsub.channel import ChannelFlags
from litecord.gateway.schemas import (
validate,
IDENTIFY_SCHEMA,
@ -101,6 +99,20 @@ def _complete_users_list(user_id: str, base_ready, user_ready) -> dict:
return ready
def calculate_intents(data) -> Intents:
intents_int = data.get("intents")
guild_subscriptions = data.get("guild_subscriptions")
if guild_subscriptions is False and intents_int is None:
intents_int = Intents(0)
intents_int |= Intents.GUILD_MESSAGE_TYPING
intents_int |= Intents.DIRECT_MESSAGE_TYPING
intents_int = ~intents_int
elif intents_int is None:
intents_int = Intents.default()
return Intents(intents_int)
class GatewayWebsocket:
"""Main gateway websocket logic."""
@ -460,7 +472,7 @@ class GatewayWebsocket:
return list(filtered)
async def subscribe_all(self, guild_subscriptions: bool):
async def subscribe_all(self):
"""Subscribe to all guilds, DM channels, and friends.
Note: subscribing to channels is already handled
@ -494,11 +506,7 @@ class GatewayWebsocket:
channel_ids: List[int] = []
for guild_id in guild_ids:
await app.dispatcher.guild.sub_with_flags(
guild_id,
session_id,
GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions),
)
await app.dispatcher.guild.sub(guild_id, session_id)
# instead of calculating which channels to subscribe to
# inside guild dispatcher, we calculate them in here, so that
@ -515,9 +523,7 @@ class GatewayWebsocket:
log.info("subscribing to {} guild channels", len(channel_ids))
for channel_id in channel_ids:
await app.dispatcher.channel.sub_with_flags(
channel_id, session_id, ChannelFlags(typing=guild_subscriptions)
)
await app.dispatcher.channel.sub(channel_id, session_id)
for dm_id in dm_ids:
await app.dispatcher.channel.sub(dm_id, session_id)
@ -668,6 +674,8 @@ class GatewayWebsocket:
shard = data.get("shard", [0, 1])
presence = data.get("presence") or {}
intents = calculate_intents(data)
try:
user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden):
@ -693,6 +701,7 @@ class GatewayWebsocket:
large=large,
current_shard=shard[0],
shard_count=shard[1],
intents=intents,
)
self.state.ws = self
@ -703,7 +712,7 @@ class GatewayWebsocket:
settings = await self.user_storage.get_user_settings(user_id)
await self.update_presence(presence, settings=settings)
await self.subscribe_all(data.get("guild_subscriptions", True))
await self.subscribe_all()
await self.dispatch_ready(settings=settings)
async def handle_3(self, payload: Dict[str, Any]):

View File

@ -18,14 +18,13 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import List
from dataclasses import dataclass
from quart import current_app as app
from logbook import Logger
from litecord.enums import ChannelType
from litecord.enums import ChannelType, EVENTS_TO_INTENTS
from litecord.utils import index_by_func
from .dispatcher import DispatcherWithFlags, GatewayEvent
from .dispatcher import DispatcherWithState, GatewayEvent
log = Logger(__name__)
@ -44,14 +43,7 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
return data
@dataclass
class ChannelFlags:
typing: bool
class ChannelDispatcher(
DispatcherWithFlags[int, str, GatewayEvent, List[str], ChannelFlags]
):
class ChannelDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
"""Main channel Pub/Sub logic. Handles both Guild, DM, and Group DM channels."""
async def dispatch(self, channel_id: int, event: GatewayEvent) -> List[str]:
@ -69,14 +61,11 @@ class ChannelDispatcher(
await self.unsub(channel_id, session_id)
continue
try:
flags = self.get_flags(channel_id, session_id)
except KeyError:
log.warning("no flags for {!r}, ignoring", session_id)
flags = ChannelFlags(typing=True)
if event_type.lower().startswith("typing_") and not flags.typing:
continue
wanted_intent = EVENTS_TO_INTENTS.get(event_type)
if wanted_intent is not None:
state_has_intent = (state.intents & wanted_intent) == wanted_intent
if not state_has_intent:
continue
correct_event = event
# for cases where we are talking about group dms, we create an edited

View File

@ -45,7 +45,7 @@ F_Map = Mapping[V, F]
GatewayEvent = Tuple[str, Any]
__all__ = ["Dispatcher", "DispatcherWithState", "DispatcherWithFlags", "GatewayEvent"]
__all__ = ["Dispatcher", "DispatcherWithState", "GatewayEvent"]
class Dispatcher(Generic[K, V, EventType, DispatchType]):
@ -123,39 +123,3 @@ class DispatcherWithState(Dispatcher[K, V, EventType, DispatchType]):
self.state.pop(key)
except KeyError:
pass
class DispatcherWithFlags(
DispatcherWithState,
Generic[K, V, EventType, DispatchType, F],
):
"""Pub/Sub backend with both a state and a flags store."""
def __init__(self):
super().__init__()
self.flags: Mapping[K, Dict[V, F]] = defaultdict(dict)
def set_flags(self, key: K, identifier: V, flags: F):
"""Set flags for the given identifier."""
self.flags[key][identifier] = flags
def remove_flags(self, key: K, identifier: V):
"""Set flags for the given identifier."""
try:
self.flags[key].pop(identifier)
except KeyError:
pass
def get_flags(self, key: K, identifier: V):
"""Get a single field from the flags store."""
return self.flags[key][identifier]
async def sub_with_flags(self, key: K, identifier: V, flags: F):
"""Subscribe a user to the guild."""
await super().sub(key, identifier)
self.set_flags(key, identifier, flags)
async def unsub(self, key: K, identifier: V):
"""Unsubscribe a user from the guild."""
await super().unsub(key, identifier)
self.remove_flags(key, identifier)

View File

@ -18,26 +18,34 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from typing import List
from dataclasses import dataclass
from quart import current_app as app
from logbook import Logger
from .dispatcher import DispatcherWithFlags, GatewayEvent
from .channel import ChannelFlags
from .dispatcher import DispatcherWithState, GatewayEvent
from litecord.gateway.state import GatewayState
from litecord.enums import EVENTS_TO_INTENTS
log = Logger(__name__)
@dataclass
class GuildFlags(ChannelFlags):
presence: bool
def can_dispatch(event_type, event_data, state) -> bool:
# If we're sending to the same user for this kind of event,
# bypass event logic (always send)
if event_type == "GUILD_MEMBER_UPDATE":
user_id = int(event_data["user"])
return user_id == state.user_id
# TODO Guild Create and Req Guild Members have specific
# logic regarding the presence intent.
wanted_intent = EVENTS_TO_INTENTS.get(event_type)
if wanted_intent is not None:
state_has_intent = (state.intents & wanted_intent) == wanted_intent
return state_has_intent
class GuildDispatcher(
DispatcherWithFlags[int, str, GatewayEvent, List[str], GuildFlags]
):
class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
"""Guild backend for Pub/Sub."""
async def sub_user(self, guild_id: int, user_id: int) -> List[GatewayState]:
@ -52,7 +60,7 @@ class GuildDispatcher(
):
session_ids = self.state[guild_id]
sessions: List[str] = []
event_type, _ = event
event_type, event_data = event
for session_id in set(session_ids):
if not filter_function(session_id):
@ -68,13 +76,7 @@ class GuildDispatcher(
await self.unsub(guild_id, session_id)
continue
try:
flags = self.get_flags(guild_id, session_id)
except KeyError:
log.warning("no flags for {!r}, ignoring", session_id)
flags = GuildFlags(presence=True, typing=True)
if event_type.lower().startswith("presence_") and not flags.presence:
if not can_dispatch(event_type, event_data, state):
continue
try: