mirror of https://gitlab.com/litecord/litecord.git
gateway: add support for intents in identify input
This commit is contained in:
parent
a12ef353d3
commit
33c7db9cbb
|
|
@ -22,6 +22,7 @@ import os
|
|||
|
||||
from typing import Optional
|
||||
from litecord.presence import BasePresence
|
||||
from litecord.enums import Intents
|
||||
|
||||
|
||||
def gen_session_id() -> str:
|
||||
|
|
@ -93,6 +94,7 @@ class GatewayState:
|
|||
self.compress: bool = kwargs.get("compress") or False
|
||||
|
||||
self.large: int = kwargs.get("large") or 50
|
||||
self.intents: Intents = kwargs["intents"]
|
||||
|
||||
def __bool__(self):
|
||||
"""Return if the given state is a valid state to be used."""
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ from logbook import Logger
|
|||
|
||||
from litecord.gateway.state import GatewayState
|
||||
from litecord.gateway.opcodes import OP
|
||||
from litecord.enums import Intents
|
||||
|
||||
|
||||
log = Logger(__name__)
|
||||
|
|
@ -174,6 +175,7 @@ class StateManager:
|
|||
"game": None,
|
||||
"since": 0,
|
||||
},
|
||||
intents=Intents.default(),
|
||||
)
|
||||
|
||||
states.append(dummy_state)
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from logbook import Logger
|
|||
from quart import current_app as app
|
||||
|
||||
from litecord.auth import raw_token_check
|
||||
from litecord.enums import RelationshipType, ChannelType, ActivityType
|
||||
from litecord.enums import RelationshipType, ChannelType, ActivityType, Intents
|
||||
from litecord.utils import (
|
||||
task_wrapper,
|
||||
yield_chunks,
|
||||
|
|
@ -101,6 +101,20 @@ def _complete_users_list(user_id: str, base_ready, user_ready) -> dict:
|
|||
return ready
|
||||
|
||||
|
||||
def calculate_intents(data) -> Intents:
|
||||
intents_int = data.get("intents")
|
||||
guild_subscriptions = data.get("guild_subscriptions")
|
||||
if guild_subscriptions is False and intents_int is None:
|
||||
intents_int = Intents(0)
|
||||
intents_int |= Intents.GUILD_MESSAGE_TYPING
|
||||
intents_int |= Intents.DIRECT_MESSAGE_TYPING
|
||||
intents_int = ~intents_int
|
||||
elif intents_int is None:
|
||||
intents_int = Intents.default()
|
||||
|
||||
return Intents(intents_int)
|
||||
|
||||
|
||||
class GatewayWebsocket:
|
||||
"""Main gateway websocket logic."""
|
||||
|
||||
|
|
@ -460,7 +474,7 @@ class GatewayWebsocket:
|
|||
|
||||
return list(filtered)
|
||||
|
||||
async def subscribe_all(self, guild_subscriptions: bool):
|
||||
async def subscribe_all(self):
|
||||
"""Subscribe to all guilds, DM channels, and friends.
|
||||
|
||||
Note: subscribing to channels is already handled
|
||||
|
|
@ -497,7 +511,7 @@ class GatewayWebsocket:
|
|||
await app.dispatcher.guild.sub_with_flags(
|
||||
guild_id,
|
||||
session_id,
|
||||
GuildFlags(presence=guild_subscriptions, typing=guild_subscriptions),
|
||||
GuildFlags(presence=True, typing=True),
|
||||
)
|
||||
|
||||
# instead of calculating which channels to subscribe to
|
||||
|
|
@ -516,7 +530,7 @@ class GatewayWebsocket:
|
|||
log.info("subscribing to {} guild channels", len(channel_ids))
|
||||
for channel_id in channel_ids:
|
||||
await app.dispatcher.channel.sub_with_flags(
|
||||
channel_id, session_id, ChannelFlags(typing=guild_subscriptions)
|
||||
channel_id, session_id, ChannelFlags(typing=True)
|
||||
)
|
||||
|
||||
for dm_id in dm_ids:
|
||||
|
|
@ -668,6 +682,8 @@ class GatewayWebsocket:
|
|||
shard = data.get("shard", [0, 1])
|
||||
presence = data.get("presence") or {}
|
||||
|
||||
intents = calculate_intents(data)
|
||||
|
||||
try:
|
||||
user_id = await raw_token_check(token, self.app.db)
|
||||
except (Unauthorized, Forbidden):
|
||||
|
|
@ -693,6 +709,7 @@ class GatewayWebsocket:
|
|||
large=large,
|
||||
current_shard=shard[0],
|
||||
shard_count=shard[1],
|
||||
intents=intents,
|
||||
)
|
||||
|
||||
self.state.ws = self
|
||||
|
|
@ -703,7 +720,7 @@ class GatewayWebsocket:
|
|||
settings = await self.user_storage.get_user_settings(user_id)
|
||||
|
||||
await self.update_presence(presence, settings=settings)
|
||||
await self.subscribe_all(data.get("guild_subscriptions", True))
|
||||
await self.subscribe_all()
|
||||
await self.dispatch_ready(settings=settings)
|
||||
|
||||
async def handle_3(self, payload: Dict[str, Any]):
|
||||
|
|
|
|||
Loading…
Reference in New Issue