diff --git a/litecord/common/guilds.py b/litecord/common/guilds.py index a8a960a..9212c2b 100644 --- a/litecord/common/guilds.py +++ b/litecord/common/guilds.py @@ -338,7 +338,4 @@ async def add_member(guild_id: int, user_id: int, *, basic=False): guild = await app.storage.get_guild_full(guild_id, user_id, 250) for state in states: - try: - await state.ws.dispatch("GUILD_CREATE", guild) - except Exception: - log.exception("failed to dispatch to session_id={!r}", state.session_id) + await state.dispatch("GUILD_CREATE", guild) diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index 5a17532..332f07c 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -19,10 +19,16 @@ along with this program. If not, see . import hashlib import os +from typing import Optional, Any + +import websockets +from logbook import Logger -from typing import Optional from litecord.presence import BasePresence from litecord.enums import Intents +from .opcodes import OP + +log = Logger(__name__) def gen_session_id() -> str: @@ -102,3 +108,30 @@ class GatewayState: def __repr__(self): return f"GatewayState" + + async def dispatch(self, event_type: str, event_data: Any) -> None: + """Dispatch an event to the underlying websocket. + + Stores the event in the state's payload store for resuming. + """ + self.seq += 1 + payload = { + "op": OP.DISPATCH, + "t": event_type.upper(), + "s": self.seq, + "d": event_data, + } + + self.store[self.seq] = payload + + log.debug("dispatching event {!r} to session {}", payload["t"], self.session_id) + + try: + await self.ws.send(payload) + except websockets.exceptions.ConnectionClosed as exc: + log.warning( + "Failed to dispatch {!r} to session id {}: {!r}", + payload["t"], + self.session_id, + exc, + ) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 59bfa24..89e4b81 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -392,11 +392,13 @@ class GatewayWebsocket: self._hb_start(interval) - async def dispatch(self, event: str, data: Any): - """Dispatch an event to the websocket.""" - assert self.state is not None - self.state.seq += 1 + async def dispatch_raw(self, event: str, data: Any): + """Dispatch an event to the websocket, bypassing the gateway state. + Only use this function for events related to connection state, + such as READY and RESUMED, or events that are replies to + messages in the websocket. + """ payload = { "op": OP.DISPATCH, "t": event.upper(), @@ -404,8 +406,6 @@ class GatewayWebsocket: "d": data, } - self.state.store[self.state.seq] = payload - log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id) try: @@ -451,7 +451,7 @@ class GatewayWebsocket: if guild is None: continue - await self.dispatch("GUILD_CREATE", guild) + await self.dispatch_raw("GUILD_CREATE", guild) async def _user_ready(self, *, settings=None) -> dict: """Fetch information about users in the READY packet. @@ -546,9 +546,9 @@ class GatewayWebsocket: for guild in full_ready_data["guilds"]: guild["members"] = [] - await self.dispatch("READY", full_ready_data) + await self.dispatch_raw("READY", full_ready_data) if self.ws_properties.version > 6: - await self.dispatch("READY_SUPPLEMENTAL", ready_supplemental) + await self.dispatch_raw("READY_SUPPLEMENTAL", ready_supplemental) app.sched.spawn(self._guild_dispatch(guilds)) async def _check_shards(self, shard, user_id): @@ -959,9 +959,9 @@ class GatewayWebsocket: return if presences: - await self.dispatch("PRESENCE_REPLACE", presences) + await self.dispatch_raw("PRESENCE_REPLACE", presences) - await self.dispatch("RESUMED", {}) + await self.dispatch_raw("RESUMED", {}) async def handle_6(self, payload: Dict[str, Any]): """Handle OP 6 Resume.""" @@ -1044,7 +1044,7 @@ class GatewayWebsocket: presences = await self.presence.guild_presences(mids, guild_id) body["presences"] = presences - await self.dispatch("GUILD_MEMBERS_CHUNK", body) + await self.dispatch_raw("GUILD_MEMBERS_CHUNK", body) async def handle_8(self, payload: Dict): """Handle OP 8 Request Guild Members.""" @@ -1088,7 +1088,7 @@ class GatewayWebsocket: log.debug(f"Syncing guild {guild_id} with {len(member_ids)} members") presences = await self.presence.guild_presences(member_ids, guild_id) - await self.dispatch( + await self.dispatch_raw( "GUILD_SYNC", {"id": str(guild_id), "presences": presences, "members": members}, ) @@ -1228,7 +1228,7 @@ class GatewayWebsocket: data = payload["d"] # stubbed - await self.dispatch( + await self.dispatch_raw( "GUILD_APPLICATION_COMMANDS_UPDATE", { "updated_at": 1630271377245, diff --git a/litecord/pubsub/channel.py b/litecord/pubsub/channel.py index 9db8d7b..4cdb9b2 100644 --- a/litecord/pubsub/channel.py +++ b/litecord/pubsub/channel.py @@ -78,12 +78,7 @@ class ChannelDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]): new_data = gdm_recipient_view(event_data, state.user_id) correct_event = (event_type, new_data) - try: - await state.ws.dispatch(*correct_event) - except Exception: - log.exception("error while dispatching to {}", state.session_id) - continue - + await state.dispatch(*correct_event) sessions.append(session_id) log.info( diff --git a/litecord/pubsub/guild.py b/litecord/pubsub/guild.py index 0085acd..73e363e 100644 --- a/litecord/pubsub/guild.py +++ b/litecord/pubsub/guild.py @@ -96,7 +96,7 @@ class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]): continue try: - await state.ws.dispatch(*event) + await state.dispatch(*event) except Exception: log.exception("error while dispatching to {}", state.session_id) continue diff --git a/litecord/pubsub/lazy_guild.py b/litecord/pubsub/lazy_guild.py index ae6b865..7c87b17 100644 --- a/litecord/pubsub/lazy_guild.py +++ b/litecord/pubsub/lazy_guild.py @@ -596,8 +596,7 @@ class GuildMemberList: if not state: continue - await state.ws.dispatch("GUILD_MEMBER_LIST_UPDATE", payload) - + await state.dispatch("GUILD_MEMBER_LIST_UPDATE", payload) dispatched.append(state.session_id) return dispatched diff --git a/litecord/pubsub/utils.py b/litecord/pubsub/utils.py index 1d6bca6..bb1981b 100644 --- a/litecord/pubsub/utils.py +++ b/litecord/pubsub/utils.py @@ -33,9 +33,8 @@ async def send_event_to_states( event, data = event_data for state in states: try: - if state.ws: - await state.ws.dispatch(event, data) - res.append(state.session_id) + await state.dispatch(event, data) + res.append(state.session_id) except Exception: log.exception("error while dispatching")