diff --git a/litecord/enums.py b/litecord/enums.py
index 78a394b..0d2c0c7 100644
--- a/litecord/enums.py
+++ b/litecord/enums.py
@@ -19,7 +19,7 @@ along with this program. If not, see .
import inspect
from typing import List, Any
-from enum import Enum, IntEnum
+from enum import Enum, IntEnum, IntFlag
class EasyEnum(Enum):
@@ -66,7 +66,7 @@ class Flags:
for attr, val in cls._attrs:
has_attr = (value & val) == val
# set attributes dynamically
- setattr(res, f"is_{attr}", has_attr)
+ setattr(res, f"is_{attr.lower()}", has_attr)
return res
@@ -247,3 +247,76 @@ class Feature(EasyEnum):
# unknown
commerce = "COMMERCE"
news = "NEWS"
+
+
+class Intents(IntFlag):
+ GUILDS = 1 << 0
+ GUILD_MEMBERS = 1 << 1
+ GUILD_BANS = 1 << 2
+ GUILD_EMOJIS = 1 << 3
+ GUILD_INTEGRATIONS = 1 << 4
+ GUILD_WEBHOOKS = 1 << 5
+ GUILD_INVITES = 1 << 6
+ GUILD_VOICE_STATES = 1 << 7
+ GUILD_PRESENCES = 1 << 8
+ GUILD_MESSAGES = 1 << 9
+ GUILD_MESSAGE_REACTIONS = 1 << 10
+ GUILD_MESSAGE_TYPING = 1 << 11
+ DIRECT_MESSAGES = 1 << 12
+ DIRECT_MESSAGE_REACTIONS = 1 << 13
+ DIRECT_MESSAGE_TYPING = 1 << 14
+
+ @classmethod
+ def default(cls):
+ return cls(-1)
+
+
+EVENTS_TO_INTENTS = {
+ "GUILD_CREATE": Intents.GUILDS,
+ "GUILD_UPDATE": Intents.GUILDS,
+ "GUILD_DELETE": Intents.GUILDS,
+ "GUILD_ROLE_CREATE": Intents.GUILDS,
+ "GUILD_ROLE_UPDATE": Intents.GUILDS,
+ "GUILD_ROLE_DELETE": Intents.GUILDS,
+ "CHANNEL_CREATE": Intents.GUILDS,
+ "CHANNEL_UPDATE": Intents.GUILDS,
+ "CHANNEL_DELETE": Intents.GUILDS,
+ "CHANNEL_PINS_UPDATE": Intents.GUILDS,
+ # --- threads not supported --
+ "THREAD_CREATE": Intents.GUILDS,
+ "THREAD_UPDATE": Intents.GUILDS,
+ "THREAD_DELETE": Intents.GUILDS,
+ "THREAD_LIST_SYNC": Intents.GUILDS,
+ "THREAD_MEMBER_UPDATE": Intents.GUILDS,
+ "THREAD_MEMBERS_UPDATE": Intents.GUILDS,
+ # --- stages not supported --
+ "STAGE_INSTANCE_CREATE": Intents.GUILDS,
+ "STAGE_INSTANCE_UPDATE": Intents.GUILDS,
+ "STAGE_INSTANCE_DELETE": Intents.GUILDS,
+ "GUILD_MEMBER_ADD": Intents.GUILD_MEMBERS,
+ "GUILD_MEMBER_UPDATE": Intents.GUILD_MEMBERS,
+ "GUILD_MEMBER_REMOVE": Intents.GUILD_MEMBERS,
+ # --- threads not supported --
+ "THREAD_MEMBERS_UPDATE ": Intents.GUILD_MEMBERS,
+ "GUILD_BAN_ADD": Intents.GUILD_BANS,
+ "GUILD_BAN_REMOVE": Intents.GUILD_BANS,
+ "GUILD_EMOJIS_UPDATE": Intents.GUILD_EMOJIS,
+ "GUILD_INTEGRATIONS_UPDATE": Intents.GUILD_INTEGRATIONS,
+ "INTEGRATION_CREATE": Intents.GUILD_INTEGRATIONS,
+ "INTEGRATION_UPDATE": Intents.GUILD_INTEGRATIONS,
+ "INTEGRATION_DELETE": Intents.GUILD_INTEGRATIONS,
+ "WEBHOOKS_UPDATE": Intents.GUILD_WEBHOOKS,
+ "INVITE_CREATE": Intents.GUILD_INVITES,
+ "INVITE_DELETE": Intents.GUILD_INVITES,
+ "VOICE_STATE_UPDATE": Intents.GUILD_VOICE_STATES,
+ "PRESENCE_UPDATE": Intents.GUILD_PRESENCES,
+ "MESSAGE_CREATE": Intents.GUILD_MESSAGES,
+ "MESSAGE_UPDATE": Intents.GUILD_MESSAGES,
+ "MESSAGE_DELETE": Intents.GUILD_MESSAGES,
+ "MESSAGE_DELETE_BULK": Intents.GUILD_MESSAGES,
+ "MESSAGE_REACTION_ADD": Intents.GUILD_MESSAGE_REACTIONS,
+ "MESSAGE_REACTION_REMOVE": Intents.GUILD_MESSAGE_REACTIONS,
+ "MESSAGE_REACTION_REMOVE_ALL": Intents.GUILD_MESSAGE_REACTIONS,
+ "MESSAGE_REACTION_REMOVE_EMOJI": Intents.GUILD_MESSAGE_REACTIONS,
+ "TYPING_START": Intents.GUILD_MESSAGE_TYPING,
+}
diff --git a/litecord/gateway/schemas.py b/litecord/gateway/schemas.py
index 4928a37..ed8d12a 100644
--- a/litecord/gateway/schemas.py
+++ b/litecord/gateway/schemas.py
@@ -21,9 +21,9 @@ from typing import Dict
from logbook import Logger
-
from litecord.gateway.errors import DecodeError
from litecord.schemas import LitecordValidator
+from litecord.enums import Intents
log = Logger(__name__)
@@ -64,6 +64,7 @@ IDENTIFY_SCHEMA = {
"large_threshold": {"type": "number", "required": False},
"shard": {"type": "list", "required": False},
"presence": {"type": "dict", "required": False},
+ "intents": {"coerce": Intents, "required": False},
},
}
},
diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py
index eb2925e..5a17532 100644
--- a/litecord/gateway/state.py
+++ b/litecord/gateway/state.py
@@ -22,6 +22,7 @@ import os
from typing import Optional
from litecord.presence import BasePresence
+from litecord.enums import Intents
def gen_session_id() -> str:
@@ -93,6 +94,7 @@ class GatewayState:
self.compress: bool = kwargs.get("compress") or False
self.large: int = kwargs.get("large") or 50
+ self.intents: Intents = kwargs["intents"]
def __bool__(self):
"""Return if the given state is a valid state to be used."""
diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py
index e464cfd..bbfdc2c 100644
--- a/litecord/gateway/state_manager.py
+++ b/litecord/gateway/state_manager.py
@@ -28,6 +28,7 @@ from logbook import Logger
from litecord.gateway.state import GatewayState
from litecord.gateway.opcodes import OP
+from litecord.enums import Intents
log = Logger(__name__)
@@ -174,6 +175,7 @@ class StateManager:
"game": None,
"since": 0,
},
+ intents=Intents.default(),
)
states.append(dummy_state)
diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py
index 8e58c28..a0a7fef 100644
--- a/litecord/gateway/websocket.py
+++ b/litecord/gateway/websocket.py
@@ -31,7 +31,7 @@ from logbook import Logger
from quart import current_app as app
from litecord.auth import raw_token_check
-from litecord.enums import RelationshipType, ChannelType, ActivityType
+from litecord.enums import RelationshipType, ChannelType, ActivityType, Intents
from litecord.utils import (
task_wrapper,
yield_chunks,
@@ -56,8 +56,6 @@ from litecord.gateway.errors import (
)
from litecord.gateway.encoding import encode_json, decode_json, encode_etf, decode_etf
from litecord.gateway.utils import WebsocketFileHandler
-from litecord.pubsub.guild import GuildFlags
-from litecord.pubsub.channel import ChannelFlags
from litecord.gateway.schemas import (
validate,
IDENTIFY_SCHEMA,
@@ -101,6 +99,20 @@ def _complete_users_list(user_id: str, base_ready, user_ready) -> dict:
return ready
+def calculate_intents(data) -> Intents:
+ intents_int = data.get("intents")
+ guild_subscriptions = data.get("guild_subscriptions")
+ if guild_subscriptions is False and intents_int is None:
+ intents_int = Intents(0)
+ intents_int |= Intents.GUILD_MESSAGE_TYPING
+ intents_int |= Intents.DIRECT_MESSAGE_TYPING
+ intents_int = ~intents_int
+ elif intents_int is None:
+ intents_int = Intents.default()
+
+ return Intents(intents_int)
+
+
class GatewayWebsocket:
"""Main gateway websocket logic."""
@@ -460,7 +472,7 @@ class GatewayWebsocket:
return list(filtered)
- async def subscribe_all(self, guild_subscriptions: bool):
+ async def subscribe_all(self):
"""Subscribe to all guilds, DM channels, and friends.
Note: subscribing to channels is already handled
@@ -494,11 +506,7 @@ class GatewayWebsocket:
channel_ids: List[int] = []
for guild_id in guild_ids:
- await app.dispatcher.guild.sub_with_flags(
- guild_id,
- session_id,
- GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions),
- )
+ await app.dispatcher.guild.sub(guild_id, session_id)
# instead of calculating which channels to subscribe to
# inside guild dispatcher, we calculate them in here, so that
@@ -515,9 +523,7 @@ class GatewayWebsocket:
log.info("subscribing to {} guild channels", len(channel_ids))
for channel_id in channel_ids:
- await app.dispatcher.channel.sub_with_flags(
- channel_id, session_id, ChannelFlags(typing=guild_subscriptions)
- )
+ await app.dispatcher.channel.sub(channel_id, session_id)
for dm_id in dm_ids:
await app.dispatcher.channel.sub(dm_id, session_id)
@@ -668,6 +674,8 @@ class GatewayWebsocket:
shard = data.get("shard", [0, 1])
presence = data.get("presence") or {}
+ intents = calculate_intents(data)
+
try:
user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden):
@@ -693,6 +701,7 @@ class GatewayWebsocket:
large=large,
current_shard=shard[0],
shard_count=shard[1],
+ intents=intents,
)
self.state.ws = self
@@ -703,7 +712,7 @@ class GatewayWebsocket:
settings = await self.user_storage.get_user_settings(user_id)
await self.update_presence(presence, settings=settings)
- await self.subscribe_all(data.get("guild_subscriptions", True))
+ await self.subscribe_all()
await self.dispatch_ready(settings=settings)
async def handle_3(self, payload: Dict[str, Any]):
diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py
index 483b94f..9db8d7b 100644
--- a/litecord/pubsub/channel.py
+++ b/litecord/pubsub/channel.py
@@ -18,14 +18,13 @@ along with this program. If not, see .
"""
from typing import List
-from dataclasses import dataclass
from quart import current_app as app
from logbook import Logger
-from litecord.enums import ChannelType
+from litecord.enums import ChannelType, EVENTS_TO_INTENTS
from litecord.utils import index_by_func
-from .dispatcher import DispatcherWithFlags, GatewayEvent
+from .dispatcher import DispatcherWithState, GatewayEvent
log = Logger(__name__)
@@ -44,14 +43,7 @@ def gdm_recipient_view(orig: dict, user_id: int) -> dict:
return data
-@dataclass
-class ChannelFlags:
- typing: bool
-
-
-class ChannelDispatcher(
- DispatcherWithFlags[int, str, GatewayEvent, List[str], ChannelFlags]
-):
+class ChannelDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
"""Main channel Pub/Sub logic. Handles both Guild, DM, and Group DM channels."""
async def dispatch(self, channel_id: int, event: GatewayEvent) -> List[str]:
@@ -69,14 +61,11 @@ class ChannelDispatcher(
await self.unsub(channel_id, session_id)
continue
- try:
- flags = self.get_flags(channel_id, session_id)
- except KeyError:
- log.warning("no flags for {!r}, ignoring", session_id)
- flags = ChannelFlags(typing=True)
-
- if event_type.lower().startswith("typing_") and not flags.typing:
- continue
+ wanted_intent = EVENTS_TO_INTENTS.get(event_type)
+ if wanted_intent is not None:
+ state_has_intent = (state.intents & wanted_intent) == wanted_intent
+ if not state_has_intent:
+ continue
correct_event = event
# for cases where we are talking about group dms, we create an edited
diff --git a/litecord/pubsub/dispatcher.py b/litecord/pubsub/dispatcher.py
index 376aea5..50e3ead 100644
--- a/litecord/pubsub/dispatcher.py
+++ b/litecord/pubsub/dispatcher.py
@@ -45,7 +45,7 @@ F_Map = Mapping[V, F]
GatewayEvent = Tuple[str, Any]
-__all__ = ["Dispatcher", "DispatcherWithState", "DispatcherWithFlags", "GatewayEvent"]
+__all__ = ["Dispatcher", "DispatcherWithState", "GatewayEvent"]
class Dispatcher(Generic[K, V, EventType, DispatchType]):
@@ -123,39 +123,3 @@ class DispatcherWithState(Dispatcher[K, V, EventType, DispatchType]):
self.state.pop(key)
except KeyError:
pass
-
-
-class DispatcherWithFlags(
- DispatcherWithState,
- Generic[K, V, EventType, DispatchType, F],
-):
- """Pub/Sub backend with both a state and a flags store."""
-
- def __init__(self):
- super().__init__()
- self.flags: Mapping[K, Dict[V, F]] = defaultdict(dict)
-
- def set_flags(self, key: K, identifier: V, flags: F):
- """Set flags for the given identifier."""
- self.flags[key][identifier] = flags
-
- def remove_flags(self, key: K, identifier: V):
- """Set flags for the given identifier."""
- try:
- self.flags[key].pop(identifier)
- except KeyError:
- pass
-
- def get_flags(self, key: K, identifier: V):
- """Get a single field from the flags store."""
- return self.flags[key][identifier]
-
- async def sub_with_flags(self, key: K, identifier: V, flags: F):
- """Subscribe a user to the guild."""
- await super().sub(key, identifier)
- self.set_flags(key, identifier, flags)
-
- async def unsub(self, key: K, identifier: V):
- """Unsubscribe a user from the guild."""
- await super().unsub(key, identifier)
- self.remove_flags(key, identifier)
diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py
index 3747b90..995837f 100644
--- a/litecord/pubsub/guild.py
+++ b/litecord/pubsub/guild.py
@@ -18,26 +18,34 @@ along with this program. If not, see .
"""
from typing import List
-from dataclasses import dataclass
from quart import current_app as app
from logbook import Logger
-from .dispatcher import DispatcherWithFlags, GatewayEvent
-from .channel import ChannelFlags
+from .dispatcher import DispatcherWithState, GatewayEvent
from litecord.gateway.state import GatewayState
+from litecord.enums import EVENTS_TO_INTENTS
log = Logger(__name__)
-@dataclass
-class GuildFlags(ChannelFlags):
- presence: bool
+def can_dispatch(event_type, event_data, state) -> bool:
+ # If we're sending to the same user for this kind of event,
+ # bypass event logic (always send)
+ if event_type == "GUILD_MEMBER_UPDATE":
+ user_id = int(event_data["user"])
+ return user_id == state.user_id
+
+ # TODO Guild Create and Req Guild Members have specific
+ # logic regarding the presence intent.
+
+ wanted_intent = EVENTS_TO_INTENTS.get(event_type)
+ if wanted_intent is not None:
+ state_has_intent = (state.intents & wanted_intent) == wanted_intent
+ return state_has_intent
-class GuildDispatcher(
- DispatcherWithFlags[int, str, GatewayEvent, List[str], GuildFlags]
-):
+class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
"""Guild backend for Pub/Sub."""
async def sub_user(self, guild_id: int, user_id: int) -> List[GatewayState]:
@@ -52,7 +60,7 @@ class GuildDispatcher(
):
session_ids = self.state[guild_id]
sessions: List[str] = []
- event_type, _ = event
+ event_type, event_data = event
for session_id in set(session_ids):
if not filter_function(session_id):
@@ -68,13 +76,7 @@ class GuildDispatcher(
await self.unsub(guild_id, session_id)
continue
- try:
- flags = self.get_flags(guild_id, session_id)
- except KeyError:
- log.warning("no flags for {!r}, ignoring", session_id)
- flags = GuildFlags(presence=True, typing=True)
-
- if event_type.lower().startswith("presence_") and not flags.presence:
+ if not can_dispatch(event_type, event_data, state):
continue
try: