mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'impl/gateway-intents' into 'master'
Add Gateway Intents support Closes #74 See merge request litecord/litecord!74
This commit is contained in:
commit
d3806353b6
|
|
@ -19,7 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
import inspect
|
||||
from typing import List, Any
|
||||
from enum import Enum, IntEnum
|
||||
from enum import Enum, IntEnum, IntFlag
|
||||
|
||||
|
||||
class EasyEnum(Enum):
|
||||
|
|
@ -66,7 +66,7 @@ class Flags:
|
|||
for attr, val in cls._attrs:
|
||||
has_attr = (value & val) == val
|
||||
# set attributes dynamically
|
||||
setattr(res, f"is_{attr}", has_attr)
|
||||
setattr(res, f"is_{attr.lower()}", has_attr)
|
||||
|
||||
return res
|
||||
|
||||
|
|
@ -247,3 +247,76 @@ class Feature(EasyEnum):
|
|||
# unknown
|
||||
commerce = "COMMERCE"
|
||||
news = "NEWS"
|
||||
|
||||
|
||||
class Intents(IntFlag):
|
||||
GUILDS = 1 << 0
|
||||
GUILD_MEMBERS = 1 << 1
|
||||
GUILD_BANS = 1 << 2
|
||||
GUILD_EMOJIS = 1 << 3
|
||||
GUILD_INTEGRATIONS = 1 << 4
|
||||
GUILD_WEBHOOKS = 1 << 5
|
||||
GUILD_INVITES = 1 << 6
|
||||
GUILD_VOICE_STATES = 1 << 7
|
||||
GUILD_PRESENCES = 1 << 8
|
||||
GUILD_MESSAGES = 1 << 9
|
||||
GUILD_MESSAGE_REACTIONS = 1 << 10
|
||||
GUILD_MESSAGE_TYPING = 1 << 11
|
||||
DIRECT_MESSAGES = 1 << 12
|
||||
DIRECT_MESSAGE_REACTIONS = 1 << 13
|
||||
DIRECT_MESSAGE_TYPING = 1 << 14
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return cls(-1)
|
||||
|
||||
|
||||
EVENTS_TO_INTENTS = {
|
||||
"GUILD_CREATE": Intents.GUILDS,
|
||||
"GUILD_UPDATE": Intents.GUILDS,
|
||||
"GUILD_DELETE": Intents.GUILDS,
|
||||
"GUILD_ROLE_CREATE": Intents.GUILDS,
|
||||
"GUILD_ROLE_UPDATE": Intents.GUILDS,
|
||||
"GUILD_ROLE_DELETE": Intents.GUILDS,
|
||||
"CHANNEL_CREATE": Intents.GUILDS,
|
||||
"CHANNEL_UPDATE": Intents.GUILDS,
|
||||
"CHANNEL_DELETE": Intents.GUILDS,
|
||||
"CHANNEL_PINS_UPDATE": Intents.GUILDS,
|
||||
# --- threads not supported --
|
||||
"THREAD_CREATE": Intents.GUILDS,
|
||||
"THREAD_UPDATE": Intents.GUILDS,
|
||||
"THREAD_DELETE": Intents.GUILDS,
|
||||
"THREAD_LIST_SYNC": Intents.GUILDS,
|
||||
"THREAD_MEMBER_UPDATE": Intents.GUILDS,
|
||||
"THREAD_MEMBERS_UPDATE": Intents.GUILDS,
|
||||
# --- stages not supported --
|
||||
"STAGE_INSTANCE_CREATE": Intents.GUILDS,
|
||||
"STAGE_INSTANCE_UPDATE": Intents.GUILDS,
|
||||
"STAGE_INSTANCE_DELETE": Intents.GUILDS,
|
||||
"GUILD_MEMBER_ADD": Intents.GUILD_MEMBERS,
|
||||
"GUILD_MEMBER_UPDATE": Intents.GUILD_MEMBERS,
|
||||
"GUILD_MEMBER_REMOVE": Intents.GUILD_MEMBERS,
|
||||
# --- threads not supported --
|
||||
"THREAD_MEMBERS_UPDATE ": Intents.GUILD_MEMBERS,
|
||||
"GUILD_BAN_ADD": Intents.GUILD_BANS,
|
||||
"GUILD_BAN_REMOVE": Intents.GUILD_BANS,
|
||||
"GUILD_EMOJIS_UPDATE": Intents.GUILD_EMOJIS,
|
||||
"GUILD_INTEGRATIONS_UPDATE": Intents.GUILD_INTEGRATIONS,
|
||||
"INTEGRATION_CREATE": Intents.GUILD_INTEGRATIONS,
|
||||
"INTEGRATION_UPDATE": Intents.GUILD_INTEGRATIONS,
|
||||
"INTEGRATION_DELETE": Intents.GUILD_INTEGRATIONS,
|
||||
"WEBHOOKS_UPDATE": Intents.GUILD_WEBHOOKS,
|
||||
"INVITE_CREATE": Intents.GUILD_INVITES,
|
||||
"INVITE_DELETE": Intents.GUILD_INVITES,
|
||||
"VOICE_STATE_UPDATE": Intents.GUILD_VOICE_STATES,
|
||||
"PRESENCE_UPDATE": Intents.GUILD_PRESENCES,
|
||||
"MESSAGE_CREATE": Intents.GUILD_MESSAGES,
|
||||
"MESSAGE_UPDATE": Intents.GUILD_MESSAGES,
|
||||
"MESSAGE_DELETE": Intents.GUILD_MESSAGES,
|
||||
"MESSAGE_DELETE_BULK": Intents.GUILD_MESSAGES,
|
||||
"MESSAGE_REACTION_ADD": Intents.GUILD_MESSAGE_REACTIONS,
|
||||
"MESSAGE_REACTION_REMOVE": Intents.GUILD_MESSAGE_REACTIONS,
|
||||
"MESSAGE_REACTION_REMOVE_ALL": Intents.GUILD_MESSAGE_REACTIONS,
|
||||
"MESSAGE_REACTION_REMOVE_EMOJI": Intents.GUILD_MESSAGE_REACTIONS,
|
||||
"TYPING_START": Intents.GUILD_MESSAGE_TYPING,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,9 +21,9 @@ from typing import Dict
|
|||
|
||||
from logbook import Logger
|
||||
|
||||
|
||||
from litecord.gateway.errors import DecodeError
|
||||
from litecord.schemas import LitecordValidator
|
||||
from litecord.enums import Intents
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
|
@ -64,6 +64,7 @@ IDENTIFY_SCHEMA = {
|
|||
"large_threshold": {"type": "number", "required": False},
|
||||
"shard": {"type": "list", "required": False},
|
||||
"presence": {"type": "dict", "required": False},
|
||||
"intents": {"coerce": Intents, "required": False},
|
||||
},
|
||||
}
|
||||
},
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ import os
|
|||
|
||||
from typing import Optional
|
||||
from litecord.presence import BasePresence
|
||||
from litecord.enums import Intents
|
||||
|
||||
|
||||
def gen_session_id() -> str:
|
||||
|
|
@ -93,6 +94,7 @@ class GatewayState:
|
|||
self.compress: bool = kwargs.get("compress") or False
|
||||
|
||||
self.large: int = kwargs.get("large") or 50
|
||||
self.intents: Intents = kwargs["intents"]
|
||||
|
||||
def __bool__(self):
|
||||
"""Return if the given state is a valid state to be used."""
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from logbook import Logger
|
|||
|
||||
from litecord.gateway.state import GatewayState
|
||||
from litecord.gateway.opcodes import OP
|
||||
from litecord.enums import Intents
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -174,6 +175,7 @@ class StateManager:
|
|||
"game": None,
|
||||
"since": 0,
|
||||
},
|
||||
intents=Intents.default(),
|
||||
)
|
||||
|
||||
states.append(dummy_state)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from logbook import Logger
|
|||
from quart import current_app as app
|
||||
|
||||
from litecord.auth import raw_token_check
|
||||
from litecord.enums import RelationshipType, ChannelType, ActivityType
|
||||
from litecord.enums import RelationshipType, ChannelType, ActivityType, Intents
|
||||
from litecord.utils import (
|
||||
task_wrapper,
|
||||
yield_chunks,
|
||||
|
|
@ -56,8 +56,6 @@ from litecord.gateway.errors import (
|
|||
)
|
||||
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
|
||||
from litecord.gateway.utils import WebsocketFileHandler
|
||||
from litecord.pubsub.guild import GuildFlags
|
||||
from litecord.pubsub.channel import ChannelFlags
|
||||
from litecord.gateway.schemas import (
|
||||
validate,
|
||||
IDENTIFY_SCHEMA,
|
||||
|
|
@ -101,6 +99,20 @@ def _complete_users_list(user_id: str, base_ready, user_ready) -> dict:
|
|||
return ready
|
||||
|
||||
|
||||
def calculate_intents(data) -> Intents:
|
||||
intents_int = data.get("intents")
|
||||
guild_subscriptions = data.get("guild_subscriptions")
|
||||
if guild_subscriptions is False and intents_int is None:
|
||||
intents_int = Intents(0)
|
||||
intents_int |= Intents.GUILD_MESSAGE_TYPING
|
||||
intents_int |= Intents.DIRECT_MESSAGE_TYPING
|
||||
intents_int = ~intents_int
|
||||
elif intents_int is None:
|
||||
intents_int = Intents.default()
|
||||
|
||||
return Intents(intents_int)
|
||||
|
||||
|
||||
class GatewayWebsocket:
|
||||
"""Main gateway websocket logic."""
|
||||
|
||||
|
|
@ -460,7 +472,7 @@ class GatewayWebsocket:
|
|||
|
||||
return list(filtered)
|
||||
|
||||
async def subscribe_all(self, guild_subscriptions: bool):
|
||||
async def subscribe_all(self):
|
||||
"""Subscribe to all guilds, DM channels, and friends.
|
||||
|
||||
Note: subscribing to channels is already handled
|
||||
|
|
@ -494,11 +506,7 @@ class GatewayWebsocket:
|
|||
channel_ids: List[int] = []
|
||||
|
||||
for guild_id in guild_ids:
|
||||
await app.dispatcher.guild.sub_with_flags(
|
||||
guild_id,
|
||||
session_id,
|
||||
GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions),
|
||||
)
|
||||
await app.dispatcher.guild.sub(guild_id, session_id)
|
||||
|
||||
# instead of calculating which channels to subscribe to
|
||||
# inside guild dispatcher, we calculate them in here, so that
|
||||
|
|
@ -515,9 +523,7 @@ class GatewayWebsocket:
|
|||
|
||||
log.info("subscribing to {} guild channels", len(channel_ids))
|
||||
for channel_id in channel_ids:
|
||||
await app.dispatcher.channel.sub_with_flags(
|
||||
channel_id, session_id, ChannelFlags(typing=guild_subscriptions)
|
||||
)
|
||||
await app.dispatcher.channel.sub(channel_id, session_id)
|
||||
|
||||
for dm_id in dm_ids:
|
||||
await app.dispatcher.channel.sub(dm_id, session_id)
|
||||
|
|
@ -668,6 +674,8 @@ class GatewayWebsocket:
|
|||
shard = data.get("shard", [0, 1])
|
||||
presence = data.get("presence") or {}
|
||||
|
||||
intents = calculate_intents(data)
|
||||
|
||||
try:
|
||||
user_id = await raw_token_check(token, self.app.db)
|
||||
except (Unauthorized, Forbidden):
|
||||
|
|
@ -693,6 +701,7 @@ class GatewayWebsocket:
|
|||
large=large,
|
||||
current_shard=shard[0],
|
||||
shard_count=shard[1],
|
||||
intents=intents,
|
||||
)
|
||||
|
||||
self.state.ws = self
|
||||
|
|
@ -703,7 +712,7 @@ class GatewayWebsocket:
|
|||
settings = await self.user_storage.get_user_settings(user_id)
|
||||
|
||||
await self.update_presence(presence, settings=settings)
|
||||
await self.subscribe_all(data.get("guild_subscriptions", True))
|
||||
await self.subscribe_all()
|
||||
await self.dispatch_ready(settings=settings)
|
||||
|
||||
async def handle_3(self, payload: Dict[str, Any]):
|
||||
|
|
|
|||
|
|
@ -18,14 +18,13 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
|
||||
from quart import current_app as app
|
||||
from logbook import Logger
|
||||
|
||||
from litecord.enums import ChannelType
|
||||
from litecord.enums import ChannelType, EVENTS_TO_INTENTS
|
||||
from litecord.utils import index_by_func
|
||||
from .dispatcher import DispatcherWithFlags, GatewayEvent
|
||||
from .dispatcher import DispatcherWithState, GatewayEvent
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
|
@ -44,14 +43,7 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
|
|||
return data
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChannelFlags:
|
||||
typing: bool
|
||||
|
||||
|
||||
class ChannelDispatcher(
|
||||
DispatcherWithFlags[int, str, GatewayEvent, List[str], ChannelFlags]
|
||||
):
|
||||
class ChannelDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
|
||||
"""Main channel Pub/Sub logic. Handles both Guild, DM, and Group DM channels."""
|
||||
|
||||
async def dispatch(self, channel_id: int, event: GatewayEvent) -> List[str]:
|
||||
|
|
@ -69,13 +61,10 @@ class ChannelDispatcher(
|
|||
await self.unsub(channel_id, session_id)
|
||||
continue
|
||||
|
||||
try:
|
||||
flags = self.get_flags(channel_id, session_id)
|
||||
except KeyError:
|
||||
log.warning("no flags for {!r}, ignoring", session_id)
|
||||
flags = ChannelFlags(typing=True)
|
||||
|
||||
if event_type.lower().startswith("typing_") and not flags.typing:
|
||||
wanted_intent = EVENTS_TO_INTENTS.get(event_type)
|
||||
if wanted_intent is not None:
|
||||
state_has_intent = (state.intents & wanted_intent) == wanted_intent
|
||||
if not state_has_intent:
|
||||
continue
|
||||
|
||||
correct_event = event
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ F_Map = Mapping[V, F]
|
|||
|
||||
GatewayEvent = Tuple[str, Any]
|
||||
|
||||
__all__ = ["Dispatcher", "DispatcherWithState", "DispatcherWithFlags", "GatewayEvent"]
|
||||
__all__ = ["Dispatcher", "DispatcherWithState", "GatewayEvent"]
|
||||
|
||||
|
||||
class Dispatcher(Generic[K, V, EventType, DispatchType]):
|
||||
|
|
@ -123,39 +123,3 @@ class DispatcherWithState(Dispatcher[K, V, EventType, DispatchType]):
|
|||
self.state.pop(key)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
||||
class DispatcherWithFlags(
|
||||
DispatcherWithState,
|
||||
Generic[K, V, EventType, DispatchType, F],
|
||||
):
|
||||
"""Pub/Sub backend with both a state and a flags store."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.flags: Mapping[K, Dict[V, F]] = defaultdict(dict)
|
||||
|
||||
def set_flags(self, key: K, identifier: V, flags: F):
|
||||
"""Set flags for the given identifier."""
|
||||
self.flags[key][identifier] = flags
|
||||
|
||||
def remove_flags(self, key: K, identifier: V):
|
||||
"""Set flags for the given identifier."""
|
||||
try:
|
||||
self.flags[key].pop(identifier)
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def get_flags(self, key: K, identifier: V):
|
||||
"""Get a single field from the flags store."""
|
||||
return self.flags[key][identifier]
|
||||
|
||||
async def sub_with_flags(self, key: K, identifier: V, flags: F):
|
||||
"""Subscribe a user to the guild."""
|
||||
await super().sub(key, identifier)
|
||||
self.set_flags(key, identifier, flags)
|
||||
|
||||
async def unsub(self, key: K, identifier: V):
|
||||
"""Unsubscribe a user from the guild."""
|
||||
await super().unsub(key, identifier)
|
||||
self.remove_flags(key, identifier)
|
||||
|
|
|
|||
|
|
@ -18,26 +18,34 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
from typing import List
|
||||
from dataclasses import dataclass
|
||||
|
||||
from quart import current_app as app
|
||||
from logbook import Logger
|
||||
|
||||
from .dispatcher import DispatcherWithFlags, GatewayEvent
|
||||
from .channel import ChannelFlags
|
||||
from .dispatcher import DispatcherWithState, GatewayEvent
|
||||
from litecord.gateway.state import GatewayState
|
||||
from litecord.enums import EVENTS_TO_INTENTS
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GuildFlags(ChannelFlags):
|
||||
presence: bool
|
||||
def can_dispatch(event_type, event_data, state) -> bool:
|
||||
# If we're sending to the same user for this kind of event,
|
||||
# bypass event logic (always send)
|
||||
if event_type == "GUILD_MEMBER_UPDATE":
|
||||
user_id = int(event_data["user"])
|
||||
return user_id == state.user_id
|
||||
|
||||
# TODO Guild Create and Req Guild Members have specific
|
||||
# logic regarding the presence intent.
|
||||
|
||||
wanted_intent = EVENTS_TO_INTENTS.get(event_type)
|
||||
if wanted_intent is not None:
|
||||
state_has_intent = (state.intents & wanted_intent) == wanted_intent
|
||||
return state_has_intent
|
||||
|
||||
|
||||
class GuildDispatcher(
|
||||
DispatcherWithFlags[int, str, GatewayEvent, List[str], GuildFlags]
|
||||
):
|
||||
class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
|
||||
"""Guild backend for Pub/Sub."""
|
||||
|
||||
async def sub_user(self, guild_id: int, user_id: int) -> List[GatewayState]:
|
||||
|
|
@ -52,7 +60,7 @@ class GuildDispatcher(
|
|||
):
|
||||
session_ids = self.state[guild_id]
|
||||
sessions: List[str] = []
|
||||
event_type, _ = event
|
||||
event_type, event_data = event
|
||||
|
||||
for session_id in set(session_ids):
|
||||
if not filter_function(session_id):
|
||||
|
|
@ -68,13 +76,7 @@ class GuildDispatcher(
|
|||
await self.unsub(guild_id, session_id)
|
||||
continue
|
||||
|
||||
try:
|
||||
flags = self.get_flags(guild_id, session_id)
|
||||
except KeyError:
|
||||
log.warning("no flags for {!r}, ignoring", session_id)
|
||||
flags = GuildFlags(presence=True, typing=True)
|
||||
|
||||
if event_type.lower().startswith("presence_") and not flags.presence:
|
||||
if not can_dispatch(event_type, event_data, state):
|
||||
continue
|
||||
|
||||
try:
|
||||
|
|
|
|||
Loading…
Reference in New Issue