mirror of https://gitlab.com/litecord/litecord.git
move dispatching logic to GatewayState
this makes gateway states have correct data for resuming
This commit is contained in:
parent
d16a893d84
commit
f05b807f60
|
|
@ -338,7 +338,4 @@ async def add_member(guild_id: int, user_id: int, *, basic=False):
|
|||
|
||||
guild = await app.storage.get_guild_full(guild_id, user_id, 250)
|
||||
for state in states:
|
||||
try:
|
||||
await state.ws.dispatch("GUILD_CREATE", guild)
|
||||
except Exception:
|
||||
log.exception("failed to dispatch to session_id={!r}", state.session_id)
|
||||
await state.dispatch("GUILD_CREATE", guild)
|
||||
|
|
|
|||
|
|
@ -19,10 +19,16 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
import hashlib
|
||||
import os
|
||||
from typing import Optional, Any
|
||||
|
||||
import websockets
|
||||
from logbook import Logger
|
||||
|
||||
from typing import Optional
|
||||
from litecord.presence import BasePresence
|
||||
from litecord.enums import Intents
|
||||
from .opcodes import OP
|
||||
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
def gen_session_id() -> str:
|
||||
|
|
@ -102,3 +108,30 @@ class GatewayState:
|
|||
|
||||
def __repr__(self):
|
||||
return f"GatewayState<seq={self.seq} shard={self.current_shard},{self.shard_count} uid={self.user_id}>"
|
||||
|
||||
async def dispatch(self, event_type: str, event_data: Any) -> None:
|
||||
"""Dispatch an event to the underlying websocket.
|
||||
|
||||
Stores the event in the state's payload store for resuming.
|
||||
"""
|
||||
self.seq += 1
|
||||
payload = {
|
||||
"op": OP.DISPATCH,
|
||||
"t": event_type.upper(),
|
||||
"s": self.seq,
|
||||
"d": event_data,
|
||||
}
|
||||
|
||||
self.store[self.seq] = payload
|
||||
|
||||
log.debug("dispatching event {!r} to session {}", payload["t"], self.session_id)
|
||||
|
||||
try:
|
||||
await self.ws.send(payload)
|
||||
except websockets.exceptions.ConnectionClosed as exc:
|
||||
log.warning(
|
||||
"Failed to dispatch {!r} to session id {}: {!r}",
|
||||
payload["t"],
|
||||
self.session_id,
|
||||
exc,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -392,11 +392,13 @@ class GatewayWebsocket:
|
|||
|
||||
self._hb_start(interval)
|
||||
|
||||
async def dispatch(self, event: str, data: Any):
|
||||
"""Dispatch an event to the websocket."""
|
||||
assert self.state is not None
|
||||
self.state.seq += 1
|
||||
async def dispatch_raw(self, event: str, data: Any):
|
||||
"""Dispatch an event to the websocket, bypassing the gateway state.
|
||||
|
||||
Only use this function for events related to connection state,
|
||||
such as READY and RESUMED, or events that are replies to
|
||||
messages in the websocket.
|
||||
"""
|
||||
payload = {
|
||||
"op": OP.DISPATCH,
|
||||
"t": event.upper(),
|
||||
|
|
@ -404,8 +406,6 @@ class GatewayWebsocket:
|
|||
"d": data,
|
||||
}
|
||||
|
||||
self.state.store[self.state.seq] = payload
|
||||
|
||||
log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id)
|
||||
|
||||
try:
|
||||
|
|
@ -451,7 +451,7 @@ class GatewayWebsocket:
|
|||
if guild is None:
|
||||
continue
|
||||
|
||||
await self.dispatch("GUILD_CREATE", guild)
|
||||
await self.dispatch_raw("GUILD_CREATE", guild)
|
||||
|
||||
async def _user_ready(self, *, settings=None) -> dict:
|
||||
"""Fetch information about users in the READY packet.
|
||||
|
|
@ -546,9 +546,9 @@ class GatewayWebsocket:
|
|||
for guild in full_ready_data["guilds"]:
|
||||
guild["members"] = []
|
||||
|
||||
await self.dispatch("READY", full_ready_data)
|
||||
await self.dispatch_raw("READY", full_ready_data)
|
||||
if self.ws_properties.version > 6:
|
||||
await self.dispatch("READY_SUPPLEMENTAL", ready_supplemental)
|
||||
await self.dispatch_raw("READY_SUPPLEMENTAL", ready_supplemental)
|
||||
app.sched.spawn(self._guild_dispatch(guilds))
|
||||
|
||||
async def _check_shards(self, shard, user_id):
|
||||
|
|
@ -959,9 +959,9 @@ class GatewayWebsocket:
|
|||
return
|
||||
|
||||
if presences:
|
||||
await self.dispatch("PRESENCE_REPLACE", presences)
|
||||
await self.dispatch_raw("PRESENCE_REPLACE", presences)
|
||||
|
||||
await self.dispatch("RESUMED", {})
|
||||
await self.dispatch_raw("RESUMED", {})
|
||||
|
||||
async def handle_6(self, payload: Dict[str, Any]):
|
||||
"""Handle OP 6 Resume."""
|
||||
|
|
@ -1044,7 +1044,7 @@ class GatewayWebsocket:
|
|||
presences = await self.presence.guild_presences(mids, guild_id)
|
||||
body["presences"] = presences
|
||||
|
||||
await self.dispatch("GUILD_MEMBERS_CHUNK", body)
|
||||
await self.dispatch_raw("GUILD_MEMBERS_CHUNK", body)
|
||||
|
||||
async def handle_8(self, payload: Dict):
|
||||
"""Handle OP 8 Request Guild Members."""
|
||||
|
|
@ -1088,7 +1088,7 @@ class GatewayWebsocket:
|
|||
log.debug(f"Syncing guild {guild_id} with {len(member_ids)} members")
|
||||
presences = await self.presence.guild_presences(member_ids, guild_id)
|
||||
|
||||
await self.dispatch(
|
||||
await self.dispatch_raw(
|
||||
"GUILD_SYNC",
|
||||
{"id": str(guild_id), "presences": presences, "members": members},
|
||||
)
|
||||
|
|
@ -1228,7 +1228,7 @@ class GatewayWebsocket:
|
|||
data = payload["d"]
|
||||
|
||||
# stubbed
|
||||
await self.dispatch(
|
||||
await self.dispatch_raw(
|
||||
"GUILD_APPLICATION_COMMANDS_UPDATE",
|
||||
{
|
||||
"updated_at": 1630271377245,
|
||||
|
|
|
|||
|
|
@ -78,12 +78,7 @@ class ChannelDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
|
|||
new_data = gdm_recipient_view(event_data, state.user_id)
|
||||
correct_event = (event_type, new_data)
|
||||
|
||||
try:
|
||||
await state.ws.dispatch(*correct_event)
|
||||
except Exception:
|
||||
log.exception("error while dispatching to {}", state.session_id)
|
||||
continue
|
||||
|
||||
await state.dispatch(*correct_event)
|
||||
sessions.append(session_id)
|
||||
|
||||
log.info(
|
||||
|
|
|
|||
|
|
@ -96,7 +96,7 @@ class GuildDispatcher(DispatcherWithState[int, str, GatewayEvent, List[str]]):
|
|||
continue
|
||||
|
||||
try:
|
||||
await state.ws.dispatch(*event)
|
||||
await state.dispatch(*event)
|
||||
except Exception:
|
||||
log.exception("error while dispatching to {}", state.session_id)
|
||||
continue
|
||||
|
|
|
|||
|
|
@ -596,8 +596,7 @@ class GuildMemberList:
|
|||
if not state:
|
||||
continue
|
||||
|
||||
await state.ws.dispatch("GUILD_MEMBER_LIST_UPDATE", payload)
|
||||
|
||||
await state.dispatch("GUILD_MEMBER_LIST_UPDATE", payload)
|
||||
dispatched.append(state.session_id)
|
||||
|
||||
return dispatched
|
||||
|
|
|
|||
|
|
@ -33,9 +33,8 @@ async def send_event_to_states(
|
|||
event, data = event_data
|
||||
for state in states:
|
||||
try:
|
||||
if state.ws:
|
||||
await state.ws.dispatch(event, data)
|
||||
res.append(state.session_id)
|
||||
await state.dispatch(event, data)
|
||||
res.append(state.session_id)
|
||||
except Exception:
|
||||
log.exception("error while dispatching")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue