move dispatching logic to GatewayState

this makes gateway states have correct data for resuming
This commit is contained in:
Luna 2021-08-30 00:09:34 -03:00
parent d16a893d84
commit f05b807f60
7 changed files with 54 additions and 31 deletions

View File

@ -338,7 +338,4 @@ async def add_member(guild_id: int, user_id: int, *, basic=False):
guild = await app.storage.get_guild_full(guild_id, user_id, 250) guild = await app.storage.get_guild_full(guild_id, user_id, 250)
for state in states: for state in states:
try: await state.dispatch("GUILD_CREATE", guild)
await state.ws.dispatch("GUILD_CREATE", guild)
except Exception:
log.exception("failed to dispatch to session_id={!r}", state.session_id)

View File

@ -19,10 +19,16 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import hashlib import hashlib
import os import os
from typing import Optional, Any
import websockets
from logbook import Logger
from typing import Optional
from litecord.presence import BasePresence from litecord.presence import BasePresence
from litecord.enums import Intents from litecord.enums import Intents
from .opcodes import OP
log = Logger(__name__)
def gen_session_id() -> str: def gen_session_id() -> str:
@ -102,3 +108,30 @@ class GatewayState:
def __repr__(self): def __repr__(self):
return f"GatewayState<seq={self.seq} shard={self.current_shard},{self.shard_count} uid={self.user_id}>" return f"GatewayState<seq={self.seq} shard={self.current_shard},{self.shard_count} uid={self.user_id}>"
async def dispatch(self, event_type: str, event_data: Any) -> None:
"""Dispatch an event to the underlying websocket.
Stores the event in the state's payload store for resuming.
"""
self.seq += 1
payload = {
"op": OP.DISPATCH,
"t": event_type.upper(),
"s": self.seq,
"d": event_data,
}
self.store[self.seq] = payload
log.debug("dispatching event {!r} to session {}", payload["t"], self.session_id)
try:
await self.ws.send(payload)
except websockets.exceptions.ConnectionClosed as exc:
log.warning(
"Failed to dispatch {!r} to session id {}: {!r}",
payload["t"],
self.session_id,
exc,
)

View File

@ -392,11 +392,13 @@ class GatewayWebsocket:
self._hb_start(interval) self._hb_start(interval)
async def dispatch(self, event: str, data: Any): async def dispatch_raw(self, event: str, data: Any):
"""Dispatch an event to the websocket.""" """Dispatch an event to the websocket, bypassing the gateway state.
assert self.state is not None
self.state.seq += 1
Only use this function for events related to connection state,
such as READY and RESUMED, or events that are replies to
messages in the websocket.
"""
payload = { payload = {
"op": OP.DISPATCH, "op": OP.DISPATCH,
"t": event.upper(), "t": event.upper(),
@ -404,8 +406,6 @@ class GatewayWebsocket:
"d": data, "d": data,
} }
self.state.store[self.state.seq] = payload
log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id) log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id)
try: try:
@ -451,7 +451,7 @@ class GatewayWebsocket:
if guild is None: if guild is None:
continue continue
await self.dispatch("GUILD_CREATE", guild) await self.dispatch_raw("GUILD_CREATE", guild)
async def _user_ready(self, *, settings=None) -> dict: async def _user_ready(self, *, settings=None) -> dict:
"""Fetch information about users in the READY packet. """Fetch information about users in the READY packet.
@ -546,9 +546,9 @@ class GatewayWebsocket:
for guild in full_ready_data["guilds"]: for guild in full_ready_data["guilds"]:
guild["members"] = [] guild["members"] = []
await self.dispatch("READY", full_ready_data) await self.dispatch_raw("READY", full_ready_data)
if self.ws_properties.version > 6: if self.ws_properties.version > 6:
await self.dispatch("READY_SUPPLEMENTAL", ready_supplemental) await self.dispatch_raw("READY_SUPPLEMENTAL", ready_supplemental)
app.sched.spawn(self._guild_dispatch(guilds)) app.sched.spawn(self._guild_dispatch(guilds))
async def _check_shards(self, shard, user_id): async def _check_shards(self, shard, user_id):
@ -959,9 +959,9 @@ class GatewayWebsocket:
return return
if presences: if presences:
await self.dispatch("PRESENCE_REPLACE", presences) await self.dispatch_raw("PRESENCE_REPLACE", presences)
await self.dispatch("RESUMED", {}) await self.dispatch_raw("RESUMED", {})
async def handle_6(self, payload: Dict[str, Any]): async def handle_6(self, payload: Dict[str, Any]):
"""Handle OP 6 Resume.""" """Handle OP 6 Resume."""
@ -1044,7 +1044,7 @@ class GatewayWebsocket:
presences = await self.presence.guild_presences(mids, guild_id) presences = await self.presence.guild_presences(mids, guild_id)
body["presences"] = presences body["presences"] = presences
await self.dispatch("GUILD_MEMBERS_CHUNK", body) await self.dispatch_raw("GUILD_MEMBERS_CHUNK", body)
async def handle_8(self, payload: Dict): async def handle_8(self, payload: Dict):
"""Handle OP 8 Request Guild Members.""" """Handle OP 8 Request Guild Members."""
@ -1088,7 +1088,7 @@ class GatewayWebsocket:
log.debug(f"Syncing guild {guild_id} with {len(member_ids)} members") log.debug(f"Syncing guild {guild_id} with {len(member_ids)} members")
presences = await self.presence.guild_presences(member_ids, guild_id) presences = await self.presence.guild_presences(member_ids, guild_id)
await self.dispatch( await self.dispatch_raw(
"GUILD_SYNC", "GUILD_SYNC",
{"id": str(guild_id), "presences": presences, "members": members}, {"id": str(guild_id), "presences": presences, "members": members},
) )
@ -1228,7 +1228,7 @@ class GatewayWebsocket:
data = payload["d"] data = payload["d"]
# stubbed # stubbed
await self.dispatch( await self.dispatch_raw(
"GUILD_APPLICATION_COMMANDS_UPDATE", "GUILD_APPLICATION_COMMANDS_UPDATE",
{ {
"updated_at": 1630271377245, "updated_at": 1630271377245,

View File

@ -78,12 +78,7 @@ class ChannelDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
new_data = gdm_recipient_view(event_data, state.user_id) new_data = gdm_recipient_view(event_data, state.user_id)
correct_event = (event_type, new_data) correct_event = (event_type, new_data)
try: await state.dispatch(*correct_event)
await state.ws.dispatch(*correct_event)
except Exception:
log.exception("error while dispatching to {}", state.session_id)
continue
sessions.append(session_id) sessions.append(session_id)
log.info( log.info(

View File

@ -96,7 +96,7 @@ class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
continue continue
try: try:
await state.ws.dispatch(*event) await state.dispatch(*event)
except Exception: except Exception:
log.exception("error while dispatching to {}", state.session_id) log.exception("error while dispatching to {}", state.session_id)
continue continue

View File

@ -596,8 +596,7 @@ class GuildMemberList:
if not state: if not state:
continue continue
await state.ws.dispatch("GUILD_MEMBER_LIST_UPDATE", payload) await state.dispatch("GUILD_MEMBER_LIST_UPDATE", payload)
dispatched.append(state.session_id) dispatched.append(state.session_id)
return dispatched return dispatched

View File

@ -33,8 +33,7 @@ async def send_event_to_states(
event, data = event_data event, data = event_data
for state in states: for state in states:
try: try:
if state.ws: await state.dispatch(event, data)
await state.ws.dispatch(event, data)
res.append(state.session_id) res.append(state.session_id)
except Exception: except Exception:
log.exception("error while dispatching") log.exception("error while dispatching")