gateway: add support for intents in identify input

This commit is contained in:
Luna 2021-07-11 01:21:46 -03:00
parent a12ef353d3
commit 33c7db9cbb
3 changed files with 26 additions and 5 deletions

View File

@ -22,6 +22,7 @@ import os
from typing import Optional from typing import Optional
from litecord.presence import BasePresence from litecord.presence import BasePresence
from litecord.enums import Intents
def gen_session_id() -> str: def gen_session_id() -> str:
@ -93,6 +94,7 @@ class GatewayState:
self.compress: bool = kwargs.get("compress") or False self.compress: bool = kwargs.get("compress") or False
self.large: int = kwargs.get("large") or 50 self.large: int = kwargs.get("large") or 50
self.intents: Intents = kwargs["intents"]
def __bool__(self): def __bool__(self):
"""Return if the given state is a valid state to be used.""" """Return if the given state is a valid state to be used."""

View File

@ -28,6 +28,7 @@ from logbook import Logger
from litecord.gateway.state import GatewayState from litecord.gateway.state import GatewayState
from litecord.gateway.opcodes import OP from litecord.gateway.opcodes import OP
from litecord.enums import Intents
log = Logger(__name__) log = Logger(__name__)
@ -174,6 +175,7 @@ class StateManager:
"game": None, "game": None,
"since": 0, "since": 0,
}, },
intents=Intents.default(),
) )
states.append(dummy_state) states.append(dummy_state)

View File

@ -31,7 +31,7 @@ from logbook import Logger
from quart import current_app as app from quart import current_app as app
from litecord.auth import raw_token_check from litecord.auth import raw_token_check
from litecord.enums import RelationshipType, ChannelType, ActivityType from litecord.enums import RelationshipType, ChannelType, ActivityType, Intents
from litecord.utils import ( from litecord.utils import (
task_wrapper, task_wrapper,
yield_chunks, yield_chunks,
@ -101,6 +101,20 @@ def _complete_users_list(user_id: str, base_ready, user_ready) -> dict:
return ready return ready
def calculate_intents(data) -> Intents:
intents_int = data.get("intents")
guild_subscriptions = data.get("guild_subscriptions")
if guild_subscriptions is False and intents_int is None:
intents_int = Intents(0)
intents_int |= Intents.GUILD_MESSAGE_TYPING
intents_int |= Intents.DIRECT_MESSAGE_TYPING
intents_int = ~intents_int
elif intents_int is None:
intents_int = Intents.default()
return Intents(intents_int)
class GatewayWebsocket: class GatewayWebsocket:
"""Main gateway websocket logic.""" """Main gateway websocket logic."""
@ -460,7 +474,7 @@ class GatewayWebsocket:
return list(filtered) return list(filtered)
async def subscribe_all(self, guild_subscriptions: bool): async def subscribe_all(self):
"""Subscribe to all guilds, DM channels, and friends. """Subscribe to all guilds, DM channels, and friends.
Note: subscribing to channels is already handled Note: subscribing to channels is already handled
@ -497,7 +511,7 @@ class GatewayWebsocket:
await app.dispatcher.guild.sub_with_flags( await app.dispatcher.guild.sub_with_flags(
guild_id, guild_id,
session_id, session_id,
GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions), GuildFlags(presence=True, typing=True),
) )
# instead of calculating which channels to subscribe to # instead of calculating which channels to subscribe to
@ -516,7 +530,7 @@ class GatewayWebsocket:
log.info("subscribing to {} guild channels", len(channel_ids)) log.info("subscribing to {} guild channels", len(channel_ids))
for channel_id in channel_ids: for channel_id in channel_ids:
await app.dispatcher.channel.sub_with_flags( await app.dispatcher.channel.sub_with_flags(
channel_id, session_id, ChannelFlags(typing=guild_subscriptions) channel_id, session_id, ChannelFlags(typing=True)
) )
for dm_id in dm_ids: for dm_id in dm_ids:
@ -668,6 +682,8 @@ class GatewayWebsocket:
shard = data.get("shard", [0, 1]) shard = data.get("shard", [0, 1])
presence = data.get("presence") or {} presence = data.get("presence") or {}
intents = calculate_intents(data)
try: try:
user_id = await raw_token_check(token, self.app.db) user_id = await raw_token_check(token, self.app.db)
except (Unauthorized, Forbidden): except (Unauthorized, Forbidden):
@ -693,6 +709,7 @@ class GatewayWebsocket:
large=large, large=large,
current_shard=shard[0], current_shard=shard[0],
shard_count=shard[1], shard_count=shard[1],
intents=intents,
) )
self.state.ws = self self.state.ws = self
@ -703,7 +720,7 @@ class GatewayWebsocket:
settings = await self.user_storage.get_user_settings(user_id) settings = await self.user_storage.get_user_settings(user_id)
await self.update_presence(presence, settings=settings) await self.update_presence(presence, settings=settings)
await self.subscribe_all(data.get("guild_subscriptions", True)) await self.subscribe_all()
await self.dispatch_ready(settings=settings) await self.dispatch_ready(settings=settings)
async def handle_3(self, payload: Dict[str, Any]): async def handle_3(self, payload: Dict[str, Any]):