diff --git a/litecord/enums.py b/litecord/enums.py index 78a394b..0d2c0c7 100644 --- a/litecord/enums.py +++ b/litecord/enums.py @@ -19,7 +19,7 @@ along with this program. If not, see . import inspect from typing import List, Any -from enum import Enum, IntEnum +from enum import Enum, IntEnum, IntFlag class EasyEnum(Enum): @@ -66,7 +66,7 @@ class Flags: for attr, val in cls._attrs: has_attr = (value & val) == val # set attributes dynamically - setattr(res, f"is_{attr}", has_attr) + setattr(res, f"is_{attr.lower()}", has_attr) return res @@ -247,3 +247,76 @@ class Feature(EasyEnum): # unknown commerce = "COMMERCE" news = "NEWS" + + +class Intents(IntFlag): + GUILDS = 1 << 0 + GUILD_MEMBERS = 1 << 1 + GUILD_BANS = 1 << 2 + GUILD_EMOJIS = 1 << 3 + GUILD_INTEGRATIONS = 1 << 4 + GUILD_WEBHOOKS = 1 << 5 + GUILD_INVITES = 1 << 6 + GUILD_VOICE_STATES = 1 << 7 + GUILD_PRESENCES = 1 << 8 + GUILD_MESSAGES = 1 << 9 + GUILD_MESSAGE_REACTIONS = 1 << 10 + GUILD_MESSAGE_TYPING = 1 << 11 + DIRECT_MESSAGES = 1 << 12 + DIRECT_MESSAGE_REACTIONS = 1 << 13 + DIRECT_MESSAGE_TYPING = 1 << 14 + + @classmethod + def default(cls): + return cls(-1) + + +EVENTS_TO_INTENTS = { + "GUILD_CREATE": Intents.GUILDS, + "GUILD_UPDATE": Intents.GUILDS, + "GUILD_DELETE": Intents.GUILDS, + "GUILD_ROLE_CREATE": Intents.GUILDS, + "GUILD_ROLE_UPDATE": Intents.GUILDS, + "GUILD_ROLE_DELETE": Intents.GUILDS, + "CHANNEL_CREATE": Intents.GUILDS, + "CHANNEL_UPDATE": Intents.GUILDS, + "CHANNEL_DELETE": Intents.GUILDS, + "CHANNEL_PINS_UPDATE": Intents.GUILDS, + # --- threads not supported -- + "THREAD_CREATE": Intents.GUILDS, + "THREAD_UPDATE": Intents.GUILDS, + "THREAD_DELETE": Intents.GUILDS, + "THREAD_LIST_SYNC": Intents.GUILDS, + "THREAD_MEMBER_UPDATE": Intents.GUILDS, + "THREAD_MEMBERS_UPDATE": Intents.GUILDS, + # --- stages not supported -- + "STAGE_INSTANCE_CREATE": Intents.GUILDS, + "STAGE_INSTANCE_UPDATE": Intents.GUILDS, + "STAGE_INSTANCE_DELETE": Intents.GUILDS, + "GUILD_MEMBER_ADD": Intents.GUILD_MEMBERS, + "GUILD_MEMBER_UPDATE": Intents.GUILD_MEMBERS, + "GUILD_MEMBER_REMOVE": Intents.GUILD_MEMBERS, + # --- threads not supported -- + "THREAD_MEMBERS_UPDATE ": Intents.GUILD_MEMBERS, + "GUILD_BAN_ADD": Intents.GUILD_BANS, + "GUILD_BAN_REMOVE": Intents.GUILD_BANS, + "GUILD_EMOJIS_UPDATE": Intents.GUILD_EMOJIS, + "GUILD_INTEGRATIONS_UPDATE": Intents.GUILD_INTEGRATIONS, + "INTEGRATION_CREATE": Intents.GUILD_INTEGRATIONS, + "INTEGRATION_UPDATE": Intents.GUILD_INTEGRATIONS, + "INTEGRATION_DELETE": Intents.GUILD_INTEGRATIONS, + "WEBHOOKS_UPDATE": Intents.GUILD_WEBHOOKS, + "INVITE_CREATE": Intents.GUILD_INVITES, + "INVITE_DELETE": Intents.GUILD_INVITES, + "VOICE_STATE_UPDATE": Intents.GUILD_VOICE_STATES, + "PRESENCE_UPDATE": Intents.GUILD_PRESENCES, + "MESSAGE_CREATE": Intents.GUILD_MESSAGES, + "MESSAGE_UPDATE": Intents.GUILD_MESSAGES, + "MESSAGE_DELETE": Intents.GUILD_MESSAGES, + "MESSAGE_DELETE_BULK": Intents.GUILD_MESSAGES, + "MESSAGE_REACTION_ADD": Intents.GUILD_MESSAGE_REACTIONS, + "MESSAGE_REACTION_REMOVE": Intents.GUILD_MESSAGE_REACTIONS, + "MESSAGE_REACTION_REMOVE_ALL": Intents.GUILD_MESSAGE_REACTIONS, + "MESSAGE_REACTION_REMOVE_EMOJI": Intents.GUILD_MESSAGE_REACTIONS, + "TYPING_START": Intents.GUILD_MESSAGE_TYPING, +} diff --git a/litecord/gateway/schemas.py b/litecord/gateway/schemas.py index 4928a37..ed8d12a 100644 --- a/litecord/gateway/schemas.py +++ b/litecord/gateway/schemas.py @@ -21,9 +21,9 @@ from typing import Dict from logbook import Logger - from litecord.gateway.errors import DecodeError from litecord.schemas import LitecordValidator +from litecord.enums import Intents log = Logger(__name__) @@ -64,6 +64,7 @@ IDENTIFY_SCHEMA = { "large_threshold": {"type": "number", "required": False}, "shard": {"type": "list", "required": False}, "presence": {"type": "dict", "required": False}, + "intents": {"coerce": Intents, "required": False}, }, } }, diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index eb2925e..5a17532 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -22,6 +22,7 @@ import os from typing import Optional from litecord.presence import BasePresence +from litecord.enums import Intents def gen_session_id() -> str: @@ -93,6 +94,7 @@ class GatewayState: self.compress: bool = kwargs.get("compress") or False self.large: int = kwargs.get("large") or 50 + self.intents: Intents = kwargs["intents"] def __bool__(self): """Return if the given state is a valid state to be used.""" diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index e464cfd..bbfdc2c 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -28,6 +28,7 @@ from logbook import Logger from litecord.gateway.state import GatewayState from litecord.gateway.opcodes import OP +from litecord.enums import Intents log = Logger(__name__) @@ -174,6 +175,7 @@ class StateManager: "game": None, "since": 0, }, + intents=Intents.default(), ) states.append(dummy_state) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 8e58c28..a0a7fef 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -31,7 +31,7 @@ from logbook import Logger from quart import current_app as app from litecord.auth import raw_token_check -from litecord.enums import RelationshipType, ChannelType, ActivityType +from litecord.enums import RelationshipType, ChannelType, ActivityType, Intents from litecord.utils import ( task_wrapper, yield_chunks, @@ -56,8 +56,6 @@ from litecord.gateway.errors import ( ) from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf from litecord.gateway.utils import WebsocketFileHandler -from litecord.pubsub.guild import GuildFlags -from litecord.pubsub.channel import ChannelFlags from litecord.gateway.schemas import ( validate, IDENTIFY_SCHEMA, @@ -101,6 +99,20 @@ def _complete_users_list(user_id: str, base_ready, user_ready) -> dict: return ready +def calculate_intents(data) -> Intents: + intents_int = data.get("intents") + guild_subscriptions = data.get("guild_subscriptions") + if guild_subscriptions is False and intents_int is None: + intents_int = Intents(0) + intents_int |= Intents.GUILD_MESSAGE_TYPING + intents_int |= Intents.DIRECT_MESSAGE_TYPING + intents_int = ~intents_int + elif intents_int is None: + intents_int = Intents.default() + + return Intents(intents_int) + + class GatewayWebsocket: """Main gateway websocket logic.""" @@ -460,7 +472,7 @@ class GatewayWebsocket: return list(filtered) - async def subscribe_all(self, guild_subscriptions: bool): + async def subscribe_all(self): """Subscribe to all guilds, DM channels, and friends. Note: subscribing to channels is already handled @@ -494,11 +506,7 @@ class GatewayWebsocket: channel_ids: List[int] = [] for guild_id in guild_ids: - await app.dispatcher.guild.sub_with_flags( - guild_id, - session_id, - GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions), - ) + await app.dispatcher.guild.sub(guild_id, session_id) # instead of calculating which channels to subscribe to # inside guild dispatcher, we calculate them in here, so that @@ -515,9 +523,7 @@ class GatewayWebsocket: log.info("subscribing to {} guild channels", len(channel_ids)) for channel_id in channel_ids: - await app.dispatcher.channel.sub_with_flags( - channel_id, session_id, ChannelFlags(typing=guild_subscriptions) - ) + await app.dispatcher.channel.sub(channel_id, session_id) for dm_id in dm_ids: await app.dispatcher.channel.sub(dm_id, session_id) @@ -668,6 +674,8 @@ class GatewayWebsocket: shard = data.get("shard", [0, 1]) presence = data.get("presence") or {} + intents = calculate_intents(data) + try: user_id = await raw_token_check(token, self.app.db) except (Unauthorized, Forbidden): @@ -693,6 +701,7 @@ class GatewayWebsocket: large=large, current_shard=shard[0], shard_count=shard[1], + intents=intents, ) self.state.ws = self @@ -703,7 +712,7 @@ class GatewayWebsocket: settings = await self.user_storage.get_user_settings(user_id) await self.update_presence(presence, settings=settings) - await self.subscribe_all(data.get("guild_subscriptions", True)) + await self.subscribe_all() await self.dispatch_ready(settings=settings) async def handle_3(self, payload: Dict[str, Any]): diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py index 483b94f..9db8d7b 100644 --- a/litecord/pubsub/channel.py +++ b/litecord/pubsub/channel.py @@ -18,14 +18,13 @@ along with this program. If not, see . """ from typing import List -from dataclasses import dataclass from quart import current_app as app from logbook import Logger -from litecord.enums import ChannelType +from litecord.enums import ChannelType, EVENTS_TO_INTENTS from litecord.utils import index_by_func -from .dispatcher import DispatcherWithFlags, GatewayEvent +from .dispatcher import DispatcherWithState, GatewayEvent log = Logger(__name__) @@ -44,14 +43,7 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict: return data -@dataclass -class ChannelFlags: - typing: bool - - -class ChannelDispatcher( - DispatcherWithFlags[int, str, GatewayEvent, List[str], ChannelFlags] -): +class ChannelDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]): """Main channel Pub/Sub logic. Handles both Guild, DM, and Group DM channels.""" async def dispatch(self, channel_id: int, event: GatewayEvent) -> List[str]: @@ -69,14 +61,11 @@ class ChannelDispatcher( await self.unsub(channel_id, session_id) continue - try: - flags = self.get_flags(channel_id, session_id) - except KeyError: - log.warning("no flags for {!r}, ignoring", session_id) - flags = ChannelFlags(typing=True) - - if event_type.lower().startswith("typing_") and not flags.typing: - continue + wanted_intent = EVENTS_TO_INTENTS.get(event_type) + if wanted_intent is not None: + state_has_intent = (state.intents & wanted_intent) == wanted_intent + if not state_has_intent: + continue correct_event = event # for cases where we are talking about group dms, we create an edited diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py index 376aea5..50e3ead 100644 --- a/litecord/pubsub/dispatcher.py +++ b/litecord/pubsub/dispatcher.py @@ -45,7 +45,7 @@ F_Map = Mapping[V, F] GatewayEvent = Tuple[str, Any] -__all__ = ["Dispatcher", "DispatcherWithState", "DispatcherWithFlags", "GatewayEvent"] +__all__ = ["Dispatcher", "DispatcherWithState", "GatewayEvent"] class Dispatcher(Generic[K, V, EventType, DispatchType]): @@ -123,39 +123,3 @@ class DispatcherWithState(Dispatcher[K, V, EventType, DispatchType]): self.state.pop(key) except KeyError: pass - - -class DispatcherWithFlags( - DispatcherWithState, - Generic[K, V, EventType, DispatchType, F], -): - """Pub/Sub backend with both a state and a flags store.""" - - def __init__(self): - super().__init__() - self.flags: Mapping[K, Dict[V, F]] = defaultdict(dict) - - def set_flags(self, key: K, identifier: V, flags: F): - """Set flags for the given identifier.""" - self.flags[key][identifier] = flags - - def remove_flags(self, key: K, identifier: V): - """Set flags for the given identifier.""" - try: - self.flags[key].pop(identifier) - except KeyError: - pass - - def get_flags(self, key: K, identifier: V): - """Get a single field from the flags store.""" - return self.flags[key][identifier] - - async def sub_with_flags(self, key: K, identifier: V, flags: F): - """Subscribe a user to the guild.""" - await super().sub(key, identifier) - self.set_flags(key, identifier, flags) - - async def unsub(self, key: K, identifier: V): - """Unsubscribe a user from the guild.""" - await super().unsub(key, identifier) - self.remove_flags(key, identifier) diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index 3747b90..995837f 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -18,26 +18,34 @@ along with this program. If not, see . """ from typing import List -from dataclasses import dataclass from quart import current_app as app from logbook import Logger -from .dispatcher import DispatcherWithFlags, GatewayEvent -from .channel import ChannelFlags +from .dispatcher import DispatcherWithState, GatewayEvent from litecord.gateway.state import GatewayState +from litecord.enums import EVENTS_TO_INTENTS log = Logger(__name__) -@dataclass -class GuildFlags(ChannelFlags): - presence: bool +def can_dispatch(event_type, event_data, state) -> bool: + # If we're sending to the same user for this kind of event, + # bypass event logic (always send) + if event_type == "GUILD_MEMBER_UPDATE": + user_id = int(event_data["user"]) + return user_id == state.user_id + + # TODO Guild Create and Req Guild Members have specific + # logic regarding the presence intent. + + wanted_intent = EVENTS_TO_INTENTS.get(event_type) + if wanted_intent is not None: + state_has_intent = (state.intents & wanted_intent) == wanted_intent + return state_has_intent -class GuildDispatcher( - DispatcherWithFlags[int, str, GatewayEvent, List[str], GuildFlags] -): +class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]): """Guild backend for Pub/Sub.""" async def sub_user(self, guild_id: int, user_id: int) -> List[GatewayState]: @@ -52,7 +60,7 @@ class GuildDispatcher( ): session_ids = self.state[guild_id] sessions: List[str] = [] - event_type, _ = event + event_type, event_data = event for session_id in set(session_ids): if not filter_function(session_id): @@ -68,13 +76,7 @@ class GuildDispatcher( await self.unsub(guild_id, session_id) continue - try: - flags = self.get_flags(guild_id, session_id) - except KeyError: - log.warning("no flags for {!r}, ignoring", session_id) - flags = GuildFlags(presence=True, typing=True) - - if event_type.lower().startswith("presence_") and not flags.presence: + if not can_dispatch(event_type, event_data, state): continue try: