litecord/litecord/gateway/state_manager.py

217 lines
5.9 KiB
Python

import asyncio
from typing import List
from collections import defaultdict
from websockets.exceptions import ConnectionClosed
from logbook import Logger
from litecord.gateway.state import GatewayState
from litecord.gateway.opcodes import OP
#: module-level logbook logger used by the state manager
log = Logger(__name__)
class ManagerClose(Exception):
    """Signal that the state manager refused an operation.

    Raised when a wrapped state mapping is accessed after the manager has
    been closed, or when a new state is stored while the manager no longer
    accepts new states.
    """
class StateDictWrapper:
    """Wrap a mapping so that access to it follows the state manager's
    lifecycle flags.

    - Item reads (``w[key]``, :meth:`get`) and deletes raise
      :class:`ManagerClose` once ``state_manager.closed`` is set.
    - Item assignment raises :class:`ManagerClose` when
      ``state_manager.accept_new`` is false, or when the manager is closed.
    - Iteration, membership tests, :meth:`pop` and :meth:`values` are NOT
      gated, so shutdown/cleanup code can still walk and drain the mapping.
    """

    def __init__(self, state_manager, mapping):
        # state_manager: any object exposing boolean ``closed`` and
        # ``accept_new`` attributes (the StateManager in this module).
        self.state_manager = state_manager
        self._map = mapping

    def _check_closed(self):
        """Raise ManagerClose if the owning manager has been closed."""
        if self.state_manager.closed:
            raise ManagerClose()

    def __getitem__(self, key):
        self._check_closed()
        return self._map[key]

    def __delitem__(self, key):
        self._check_closed()
        del self._map[key]

    def __setitem__(self, key, value):
        if not self.state_manager.accept_new:
            raise ManagerClose()

        self._check_closed()
        self._map[key] = value

    def __iter__(self):
        return iter(self._map)

    def __contains__(self, key):
        # O(1) hash lookup; without this method ``in`` falls back to a
        # linear scan over __iter__. Left ungated to match __iter__, which
        # is what membership tests used before.
        return key in self._map

    def get(self, key, default=None):
        """Gated lookup returning ``default`` for missing keys.

        Unlike ``self[key]``, this never invokes a defaultdict's factory,
        so it does not insert entries for missing keys.
        """
        self._check_closed()
        return self._map.get(key, default)

    def pop(self, key, *default):
        """Ungated ``dict.pop``; accepts an optional default like dict.pop."""
        return self._map.pop(key, *default)

    def values(self):
        """Ungated live view of the wrapped mapping's values."""
        return self._map.values()
class StateManager:
    """Manager for gateway state information.

    Keeps two indexes over the same GatewayState objects: per user
    (``states``) and by session id (``states_raw``). Both are
    StateDictWrapper instances, so they honour the ``closed`` and
    ``accept_new`` flags below.
    """

    def __init__(self):
        #: closed flag; once set, gated lookups raise ManagerClose
        self.closed = False

        #: accept new states? cleared by gen_close_tasks
        self.accept_new = True

        # {
        #   user_id: {
        #     session_id: GatewayState,
        #     session_id_2: GatewayState, ...
        #   },
        #   user_id_2: {}, ...
        # }
        # Keep a direct handle on the underlying defaultdict so read paths
        # can look a user up without the defaultdict creating (and keeping
        # forever) an empty entry for every unknown user_id.
        self._user_map = defaultdict(dict)
        self.states = StateDictWrapper(self, self._user_map)

        #: raw mapping from session ids to GatewayState
        self.states_raw = StateDictWrapper(self, {})

    def _peek_user_states(self, user_id: int) -> dict:
        """Return the session_id -> state dict for a user without inserting
        anything for unknown users (an empty dict is returned instead).

        Raises ManagerClose when the manager is closed, mirroring
        ``self.states[user_id]``.
        """
        if self.closed:
            raise ManagerClose()

        return self._user_map.get(user_id, {})

    def insert(self, state: GatewayState):
        """Insert a new state object into both indexes.

        Raises ManagerClose when the manager is closed or no longer
        accepting new states; in that case neither index is modified.
        """
        log.debug('inserting state: {!r}', state)

        # states_raw.__setitem__ performs the accept_new/closed checks.
        # Doing it first means a rejected insert can't leave the state
        # registered in the per-user index but missing from states_raw.
        self.states_raw[state.session_id] = state
        self.states[state.user_id][state.session_id] = state

    def fetch(self, user_id: int, session_id: str) -> GatewayState:
        """Fetch a state object from the manager.

        Raises
        ------
        KeyError
            When the user_id or session_id
            aren't found in the store.
        ManagerClose
            When the manager is closed.
        """
        return self._peek_user_states(user_id)[session_id]

    def fetch_raw(self, session_id: str) -> GatewayState:
        """Fetch a single state given the Session ID.

        Raises KeyError if unknown, ManagerClose if the manager is closed.
        """
        return self.states_raw[session_id]

    def remove(self, state):
        """Remove a state from both indexes.

        Falsy ``state`` and missing entries are ignored. If the manager is
        closed, ManagerClose is raised after the session-id entry has
        already been dropped.
        """
        if not state:
            return

        try:
            self.states_raw.pop(state.session_id)
        except KeyError:
            pass

        log.debug('removing state: {!r}', state)
        sessions = self._peek_user_states(state.user_id)
        sessions.pop(state.session_id, None)

        # drop the user's bucket once it is empty so users that
        # disconnected don't accumulate empty dicts
        if not sessions:
            self._user_map.pop(state.user_id, None)

    def fetch_states(self, user_id: int, guild_id: int) -> List[GatewayState]:
        """Fetch all states of a user whose shard handles ``guild_id``.

        Returns an empty list for users without stored states.
        """
        # A state handles the guild when (guild_id >> 22) % shard_count
        # equals its current_shard. This works if shard_count == 1 (the
        # default for single gw connections) since N % 1 is always 0.
        return [
            state for state in self._peek_user_states(user_id).values()
            if (guild_id >> 22) % state.shard_count == state.current_shard
        ]

    def user_states(self, user_id: int) -> List[GatewayState]:
        """Fetch all states tied to a single user (empty list if none)."""
        return list(self._peek_user_states(user_id).values())

    def guild_states(self, member_ids: List[int],
                     guild_id: int) -> List[GatewayState]:
        """Fetch all possible states about members in a guild.

        Members with no matching state get a single dummy offline state
        (empty session_id) so callers always see every member.
        """
        states = []

        for member_id in member_ids:
            member_states = self.fetch_states(member_id, guild_id)

            # member_states is empty if the user never logged in
            # since server start, so we need to add a dummy state
            if not member_states:
                dummy_state = GatewayState(
                    session_id='',
                    user_id=member_id,
                    presence={
                        'afk': False,
                        'status': 'offline',
                        'game': None,
                        'since': 0
                    }
                )

                states.append(dummy_state)
                continue

            # push all available member states to the result
            # array
            states.extend(member_states)

        return states

    async def shutdown_single(self, state: GatewayState):
        """Send OP Reconnect to a single connection, then close it with
        code 4000.

        If ConnectionClosed is raised while sending or closing (the client
        already went away), it is logged and otherwise ignored.
        """
        websocket = state.ws

        try:
            # the send is inside the try too: sending to a connection that
            # is already gone presumably raises ConnectionClosed as well
            await websocket.send({
                'op': OP.RECONNECT
            })

            # wait 200ms so that the client has time to process
            # our payload, then close the connection ourselves
            await asyncio.sleep(0.2)

            await websocket.ws.close(
                code=4000,
                reason='litecord shutting down'
            )
        except ConnectionClosed:
            log.info('client {} already closed', state)

    def gen_close_tasks(self):
        """Stop accepting new states and build one shutdown coroutine per
        state that has a websocket.

        The returned coroutines are not scheduled or awaited here; the
        caller is responsible for running them (e.g. with asyncio.gather).
        This is required to be run before :meth:`StateManager.close`.
        """
        self.accept_new = False

        #: the shutdown coroutines, one per connected state
        tasks = [
            self.shutdown_single(state)
            for state in self.states_raw.values()
            if state.ws
        ]

        log.info('made {} shutdown tasks', len(tasks))
        return tasks

    def close(self):
        """Close the state manager; gated lookups raise ManagerClose after this."""
        self.closed = True