diff --git a/litecord/common/users.py b/litecord/common/users.py index fa9b825..1e5bdf4 100644 --- a/litecord/common/users.py +++ b/litecord/common/users.py @@ -268,7 +268,7 @@ async def user_disconnect(user_id: int): for state in user_states: # make it unable to resume - app.state_manager.remove(state) + app.state_manager.remove(state.session_id, user_id=user_id) if not state.ws: continue diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index 6b92eb9..08a4114 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -19,7 +19,7 @@ along with this program. If not, see . import asyncio -from typing import List +from typing import List, Optional from collections import defaultdict from quart import current_app as app @@ -96,6 +96,8 @@ class StateManager: #: raw mapping from session ids to GatewayState self.states_raw = StateDictWrapper(self, {}) + self.tasks = {} + def insert(self, state: GatewayState): """Insert a new state object.""" user_states = self.states[state.user_id] @@ -119,21 +121,20 @@ class StateManager: """Fetch a single state given the Session ID.""" return self.states_raw[session_id] - def remove(self, state): + def remove(self, session_id: str, *, user_id: Optional[int] = None): """Remove a state from the registry""" - if not state: - return - try: - self.states_raw.pop(state.session_id) + state = self.states_raw.pop(session_id) + user_id = state.user_id except KeyError: pass - try: - log.debug("removing state: {!r}", state) - self.states[state.user_id].pop(state.session_id) - except KeyError: - pass + if user_id is not None: + try: + log.debug("removing state: {!r}", state) + self.states[state.user_id].pop(session_id) + except KeyError: + pass def fetch_states(self, user_id: int, guild_id: int) -> List[GatewayState]: """Fetch all states that are tied to a guild.""" @@ -188,14 +189,14 @@ class StateManager: """Send OP Reconnect to a single connection.""" websocket = state.ws - await websocket.send({"op": OP.RECONNECT}) - - # wait 200ms - # so that the client has time to process - # our payload then close the connection - await asyncio.sleep(0.2) - try: + await websocket.send({"op": OP.RECONNECT}) + + # wait 200ms + # so that the client has time to process + # our payload then close the connection + await asyncio.sleep(0.2) + # try to close the connection ourselves await websocket.ws.close(code=4000, reason="litecord shutting down") except ConnectionClosed: @@ -239,3 +240,21 @@ class StateManager: # DMs and GDMs use all user states return self.user_states(user_id) + + async def _future_cleanup(self, state: GatewayState): + await asyncio.sleep(30) + self.remove(state) + state.ws.state = None + state.ws = None + + async def schedule_deletion(self, state: GatewayState): + task = app.loop.create_task(self._future_cleanup(state)) + self.tasks[state.session_id] = task + + async def unschedule_deletion(self, state: GatewayState): + try: + task = self.tasks.pop(state.session_id) + except KeyError: + return + + task.cancel() diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index a98752f..fffc37f 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -267,9 +267,15 @@ class GatewayWebsocket: log.debug("sending payload {!r} sid {}", event.upper(), self.state.session_id) - await self.send(payload) + try: + await self.send(payload) + except websockets.exceptions.ConnectionClosed: + log.warning( + "Failed to dispatch {!r} to {}", event.upper, self.state.session_id + ) async def _make_guild_list(self) -> List[Dict[str, Any]]: + assert self.state is not None user_id = self.state.user_id guild_ids = await self._guild_ids() @@ -764,10 +770,11 @@ class GatewayWebsocket: # since the state will be removed from # the manager, it will become unreachable # when trying to resume. - self.app.state_manager.remove(self.state) + self.app.state_manager.remove(self.state.user_id) async def _resume(self, replay_seqs: Iterable): - presences = [] + assert self.state is not None + presences: List[dict] = [] try: for seq in replay_seqs: @@ -824,6 +831,7 @@ class GatewayWebsocket: return await self.invalidate_session(False) # relink this connection + await self.app.state_manager.unschedule_deletion(state) self.state = state state.ws = self @@ -1085,8 +1093,8 @@ class GatewayWebsocket: task.cancel() if self.state: - self.app.state_manager.remove(self.state) self.state.ws = None + self.app.state_manager.schedule_deletion(self.state) self.state = None async def _check_conns(self, user_id): diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 6402f21..ed49838 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -176,3 +176,79 @@ async def test_etf(test_cli): assert hello["op"] == OP.HELLO finally: await _close(conn) + + +@pytest.mark.asyncio +async def test_resume(test_cli_user): + conn = await gw_start(test_cli_user.cli) + + # get the hello frame but ignore it + await _json(conn) + + await _json_send( + conn, {"op": OP.IDENTIFY, "d": {"token": test_cli_user.user["token"]}} + ) + + try: + ready = await _json(conn) + assert isinstance(ready, dict) + assert ready["op"] == OP.DISPATCH + assert ready["t"] == "READY" + + data = ready["d"] + assert isinstance(data, dict) + + assert isinstance(data["session_id"], str) + sess_id: str = data["session_id"] + finally: + await _close(conn) + + # try to resume + conn = await gw_start(test_cli_user.cli) + _ = await _json(conn) + + await _json_send( + conn, + { + "op": OP.RESUME, + "d": { + "token": test_cli_user.user["token"], + "session_id": sess_id, + "seq": 0, + }, + }, + ) + + msg = await _json(conn) + assert isinstance(msg, dict) + assert isinstance(msg["op"], int) + assert msg["op"] == OP.DISPATCH + assert isinstance(msg["t"], str) + assert msg["t"] in ("RESUMED", "PRESENCE_REPLACE") + + # close again, and retry again, but this time by removing the state + # and asserting the session won't be resumed. + await _close(conn) + + conn = await gw_start(test_cli_user.cli) + _ = await _json(conn) + + async with test_cli_user.app.app_context(): + test_cli_user.app.state_manager.remove(sess_id) + + await _json_send( + conn, + { + "op": OP.RESUME, + "d": { + "token": test_cli_user.user["token"], + "session_id": sess_id, + "seq": 0, + }, + }, + ) + + msg = await _json(conn) + assert isinstance(msg, dict) + assert isinstance(msg["op"], int) + assert msg["op"] == OP.INVALID_SESSION