From 33c7db9cbb6d8629282d0e1d96fd04f0302b2aef Mon Sep 17 00:00:00 2001 From: Luna Date: Sun, 11 Jul 2021 01:21:46 -0300 Subject: [PATCH] gateway: add support for intents in identify input --- litecord/gateway/state.py | 2 ++ litecord/gateway/state_manager.py | 2 ++ litecord/gateway/websocket.py | 27 ++++++++++++++++++++++----- 3 files changed, 26 insertions(+), 5 deletions(-) diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index eb2925e..5a17532 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -22,6 +22,7 @@ import os from typing import Optional from litecord.presence import BasePresence +from litecord.enums import Intents def gen_session_id() -> str: @@ -93,6 +94,7 @@ class GatewayState: self.compress: bool = kwargs.get("compress") or False self.large: int = kwargs.get("large") or 50 + self.intents: Intents = kwargs["intents"] def __bool__(self): """Return if the given state is a valid state to be used.""" diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index e464cfd..bbfdc2c 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -28,6 +28,7 @@ from logbook import Logger from litecord.gateway.state import GatewayState from litecord.gateway.opcodes import OP +from litecord.enums import Intents log = Logger(__name__) @@ -174,6 +175,7 @@ class StateManager: "game": None, "since": 0, }, + intents=Intents.default(), ) states.append(dummy_state) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 8e58c28..06c31d8 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -31,7 +31,7 @@ from logbook import Logger from quart import current_app as app from litecord.auth import raw_token_check -from litecord.enums import RelationshipType, ChannelType, ActivityType +from litecord.enums import RelationshipType, ChannelType, ActivityType, Intents from litecord.utils import ( task_wrapper, yield_chunks, @@ -101,6 +101,20 @@ def _complete_users_list(user_id: str, base_ready, user_ready) -> dict: return ready +def calculate_intents(data) -> Intents: + intents_int = data.get("intents") + guild_subscriptions = data.get("guild_subscriptions") + if guild_subscriptions is False and intents_int is None: + intents_int = Intents(0) + intents_int |= Intents.GUILD_MESSAGE_TYPING + intents_int |= Intents.DIRECT_MESSAGE_TYPING + intents_int = ~intents_int + elif intents_int is None: + intents_int = Intents.default() + + return Intents(intents_int) + + class GatewayWebsocket: """Main gateway websocket logic.""" @@ -460,7 +474,7 @@ class GatewayWebsocket: return list(filtered) - async def subscribe_all(self, guild_subscriptions: bool): + async def subscribe_all(self): """Subscribe to all guilds, DM channels, and friends. Note: subscribing to channels is already handled @@ -497,7 +511,7 @@ class GatewayWebsocket: await app.dispatcher.guild.sub_with_flags( guild_id, session_id, - GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions), + GuildFlags(presence=True, typing=True), ) # instead of calculating which channels to subscribe to @@ -516,7 +530,7 @@ class GatewayWebsocket: log.info("subscribing to {} guild channels", len(channel_ids)) for channel_id in channel_ids: await app.dispatcher.channel.sub_with_flags( - channel_id, session_id, ChannelFlags(typing=guild_subscriptions) + channel_id, session_id, ChannelFlags(typing=True) ) for dm_id in dm_ids: @@ -668,6 +682,8 @@ class GatewayWebsocket: shard = data.get("shard", [0, 1]) presence = data.get("presence") or {} + intents = calculate_intents(data) + try: user_id = await raw_token_check(token, self.app.db) except (Unauthorized, Forbidden): @@ -693,6 +709,7 @@ class GatewayWebsocket: large=large, current_shard=shard[0], shard_count=shard[1], + intents=intents, ) self.state.ws = self @@ -703,7 +720,7 @@ class GatewayWebsocket: settings = await self.user_storage.get_user_settings(user_id) await self.update_presence(presence, settings=settings) - await self.subscribe_all(data.get("guild_subscriptions", True)) + await self.subscribe_all() await self.dispatch_ready(settings=settings) async def handle_3(self, payload: Dict[str, Any]):