Change remove() declaration

Allow Session IDs to be passed, instead of full state objects.
This commit is contained in:
Luna 2020-04-05 15:08:18 -03:00
parent 8e6bbdbe19
commit 7b6b696717
3 changed files with 14 additions and 14 deletions

View File

@ -268,7 +268,7 @@ async def user_disconnect(user_id: int):
for state in user_states:
# make it unable to resume
app.state_manager.remove(state)
app.state_manager.remove(state.session_id, user_id=user_id)
if not state.ws:
continue

View File

@ -19,7 +19,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
from typing import List
from typing import List, Optional
from collections import defaultdict
from quart import current_app as app
@ -121,21 +121,20 @@ class StateManager:
"""Fetch a single state given the Session ID."""
return self.states_raw[session_id]
def remove(self, state):
def remove(self, session_id: str, *, user_id: Optional[int] = None):
"""Remove a state from the registry"""
if not state:
return
try:
self.states_raw.pop(state.session_id)
state = self.states_raw.pop(session_id)
user_id = state.user_id
except KeyError:
pass
try:
log.debug("removing state: {!r}", state)
self.states[state.user_id].pop(state.session_id)
except KeyError:
pass
if user_id is not None:
try:
log.debug("removing state: {!r}", state)
self.states[state.user_id].pop(session_id)
except KeyError:
pass
def fetch_states(self, user_id: int, guild_id: int) -> List[GatewayState]:
"""Fetch all states that are tied to a guild."""

View File

@ -770,10 +770,11 @@ class GatewayWebsocket:
# since the state will be removed from
# the manager, it will become unreachable
# when trying to resume.
self.app.state_manager.remove(self.state)
self.app.state_manager.remove(self.state.user_id)
async def _resume(self, replay_seqs: Iterable):
presences = []
assert self.state is not None
presences: List[dict] = []
try:
for seq in replay_seqs: