move dispatching logic to GatewayState

this makes gateway states have correct data for resuming
This commit is contained in:
Luna 2021-08-30 00:09:34 -03:00
parent d16a893d84
commit f05b807f60
7 changed files with 54 additions and 31 deletions

View File

@ -338,7 +338,4 @@ async def add_member(guild_id: int, user_id: int, *, basic=False):
guild = await app.storage.get_guild_full(guild_id, user_id, 250)
for state in states:
try:
await state.ws.dispatch("GUILD_CREATE", guild)
except Exception:
log.exception("failed to dispatch to session_id={!r}", state.session_id)
await state.dispatch("GUILD_CREATE", guild)

View File

@ -19,10 +19,16 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import hashlib
import os
from typing import Optional, Any
import websockets
from logbook import Logger
from typing import Optional
from litecord.presence import BasePresence
from litecord.enums import Intents
from .opcodes import OP
log = Logger(__name__)
def gen_session_id() -> str:
@ -102,3 +108,30 @@ class GatewayState:
def __repr__(self):
return f"GatewayState<seq={self.seq} shard={self.current_shard},{self.shard_count} uid={self.user_id}>"
async def dispatch(self, event_type: str, event_data: Any) -> None:
"""Dispatch an event to the underlying websocket.
Stores the event in the state's payload store for resuming.
"""
self.seq += 1
payload = {
"op": OP.DISPATCH,
"t": event_type.upper(),
"s": self.seq,
"d": event_data,
}
self.store[self.seq] = payload
log.debug("dispatching event {!r} to session {}", payload["t"], self.session_id)
try:
await self.ws.send(payload)
except websockets.exceptions.ConnectionClosed as exc:
log.warning(
"Failed to dispatch {!r} to session id {}: {!r}",
payload["t"],
self.session_id,
exc,
)

View File

@ -392,11 +392,13 @@ class GatewayWebsocket:
self._hb_start(interval)
async def dispatch(self, event: str, data: Any):
"""Dispatch an event to the websocket."""
assert self.state is not None
self.state.seq += 1
async def dispatch_raw(self, event: str, data: Any):
"""Dispatch an event to the websocket, bypassing the gateway state.
Only use this function for events related to connection state,
such as READY and RESUMED, or events that are replies to
messages in the websocket.
"""
payload = {
"op": OP.DISPATCH,
"t": event.upper(),
@ -404,8 +406,6 @@ class GatewayWebsocket:
"d": data,
}
self.state.store[self.state.seq] = payload
log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id)
try:
@ -451,7 +451,7 @@ class GatewayWebsocket:
if guild is None:
continue
await self.dispatch("GUILD_CREATE", guild)
await self.dispatch_raw("GUILD_CREATE", guild)
async def _user_ready(self, *, settings=None) -> dict:
"""Fetch information about users in the READY packet.
@ -546,9 +546,9 @@ class GatewayWebsocket:
for guild in full_ready_data["guilds"]:
guild["members"] = []
await self.dispatch("READY", full_ready_data)
await self.dispatch_raw("READY", full_ready_data)
if self.ws_properties.version > 6:
await self.dispatch("READY_SUPPLEMENTAL", ready_supplemental)
await self.dispatch_raw("READY_SUPPLEMENTAL", ready_supplemental)
app.sched.spawn(self._guild_dispatch(guilds))
async def _check_shards(self, shard, user_id):
@ -959,9 +959,9 @@ class GatewayWebsocket:
return
if presences:
await self.dispatch("PRESENCE_REPLACE", presences)
await self.dispatch_raw("PRESENCE_REPLACE", presences)
await self.dispatch("RESUMED", {})
await self.dispatch_raw("RESUMED", {})
async def handle_6(self, payload: Dict[str, Any]):
"""Handle OP 6 Resume."""
@ -1044,7 +1044,7 @@ class GatewayWebsocket:
presences = await self.presence.guild_presences(mids, guild_id)
body["presences"] = presences
await self.dispatch("GUILD_MEMBERS_CHUNK", body)
await self.dispatch_raw("GUILD_MEMBERS_CHUNK", body)
async def handle_8(self, payload: Dict):
"""Handle OP 8 Request Guild Members."""
@ -1088,7 +1088,7 @@ class GatewayWebsocket:
log.debug(f"Syncing guild {guild_id} with {len(member_ids)} members")
presences = await self.presence.guild_presences(member_ids, guild_id)
await self.dispatch(
await self.dispatch_raw(
"GUILD_SYNC",
{"id": str(guild_id), "presences": presences, "members": members},
)
@ -1228,7 +1228,7 @@ class GatewayWebsocket:
data = payload["d"]
# stubbed
await self.dispatch(
await self.dispatch_raw(
"GUILD_APPLICATION_COMMANDS_UPDATE",
{
"updated_at": 1630271377245,

View File

@ -78,12 +78,7 @@ class ChannelDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
new_data = gdm_recipient_view(event_data, state.user_id)
correct_event = (event_type, new_data)
try:
await state.ws.dispatch(*correct_event)
except Exception:
log.exception("error while dispatching to {}", state.session_id)
continue
await state.dispatch(*correct_event)
sessions.append(session_id)
log.info(

View File

@ -96,7 +96,7 @@ class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
continue
try:
await state.ws.dispatch(*event)
await state.dispatch(*event)
except Exception:
log.exception("error while dispatching to {}", state.session_id)
continue

View File

@ -596,8 +596,7 @@ class GuildMemberList:
if not state:
continue
await state.ws.dispatch("GUILD_MEMBER_LIST_UPDATE", payload)
await state.dispatch("GUILD_MEMBER_LIST_UPDATE", payload)
dispatched.append(state.session_id)
return dispatched

View File

@ -33,9 +33,8 @@ async def send_event_to_states(
event, data = event_data
for state in states:
try:
if state.ws:
await state.ws.dispatch(event, data)
res.append(state.session_id)
await state.dispatch(event, data)
res.append(state.session_id)
except Exception:
log.exception("error while dispatching")