diff --git a/litecord/common/users.py b/litecord/common/users.py index fa9b825..1e5bdf4 100644 --- a/litecord/common/users.py +++ b/litecord/common/users.py @@ -268,7 +268,7 @@ async def user_disconnect(user_id: int): for state in user_states: # make it unable to resume - app.state_manager.remove(state) + app.state_manager.remove(state.session_id, user_id=user_id) if not state.ws: continue diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index 9239159..08a4114 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -19,7 +19,7 @@ along with this program. If not, see . import asyncio -from typing import List +from typing import List, Optional from collections import defaultdict from quart import current_app as app @@ -121,21 +121,20 @@ class StateManager: """Fetch a single state given the Session ID.""" return self.states_raw[session_id] - def remove(self, state): + def remove(self, session_id: str, *, user_id: Optional[int] = None): """Remove a state from the registry""" - if not state: - return - try: - self.states_raw.pop(state.session_id) + state = self.states_raw.pop(session_id) + user_id = state.user_id except KeyError: pass - try: - log.debug("removing state: {!r}", state) - self.states[state.user_id].pop(state.session_id) - except KeyError: - pass + if user_id is not None: + try: + log.debug("removing state: {!r}", state) + self.states[state.user_id].pop(session_id) + except KeyError: + pass def fetch_states(self, user_id: int, guild_id: int) -> List[GatewayState]: """Fetch all states that are tied to a guild.""" diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 6ef17b7..fffc37f 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -770,10 +770,11 @@ class GatewayWebsocket: # since the state will be removed from # the manager, it will become unreachable # when trying to resume. - self.app.state_manager.remove(self.state) + self.app.state_manager.remove(self.state.user_id) async def _resume(self, replay_seqs: Iterable): - presences = [] + assert self.state is not None + presences: List[dict] = [] try: for seq in replay_seqs: