diff --git a/litecord/gateway/state_manager.py b/litecord/gateway/state_manager.py index e393a79..5185b1d 100644 --- a/litecord/gateway/state_manager.py +++ b/litecord/gateway/state_manager.py @@ -1,18 +1,68 @@ +import asyncio + from typing import List, Dict, Any from collections import defaultdict +from websockets.exceptions import ConnectionClosed from logbook import Logger -from .state import GatewayState +from litecord.gateway.state import GatewayState +from litecord.gateway.opcodes import OP log = Logger(__name__) +class ManagerClose(Exception): + pass + + +class StateDictWrapper: + """Wrap a mapping so that any kind of access to the mapping while the + state manager is closed raises a ManagerClose error""" + def __init__(self, state_manager, mapping): + self.state_manager = state_manager + self._map = mapping + + def _check_closed(self): + if self.state_manager.closed: + raise ManagerClose() + + def __getitem__(self, key): + self._check_closed() + return self._map[key] + + def __delitem__(self, key): + self._check_closed() + del self._map[key] + + def __setitem__(self, key, value): + if not self.state_manager.accept_new: + raise ManagerClose() + + self._check_closed() + self._map[key] = value + + def __iter__(self): + return self._map.__iter__() + + def pop(self, key): + return self._map.pop(key) + + def values(self): + return self._map.values() + + class StateManager: """Manager for gateway state information.""" def __init__(self): + #: closed flag + self.closed = False + + #: accept new states? + self.accept_new = True + # { # user_id: { # session_id: GatewayState, @@ -20,10 +70,10 @@ class StateManager: # }, # user_id_2: {}, ... # } - self.states = defaultdict(dict) + self.states = StateDictWrapper(self, defaultdict(dict)) #: raw mapping from session ids to GatewayState - self.states_raw = {} + self.states_raw = StateDictWrapper(self, {}) def insert(self, state: GatewayState): """Insert a new state object.""" @@ -113,3 +163,54 @@ class StateManager: states.extend(member_states) return states + + async def shutdown_single(self, state: GatewayState): + """Send OP Reconnect to a single connection.""" + websocket = state.ws + + await websocket.send({ + 'op': OP.RECONNECT + }) + + # wait 200ms + # so that the client has time to process + # our payload then close the connection + await asyncio.sleep(0.2) + + try: + # try to close the connection ourselves + await websocket.ws.close( + code=4000, + reason='litecord shutting down' + ) + except ConnectionClosed: + log.info('client {} already closed', state) + + def gen_close_tasks(self): + """Generate the tasks that will order the clients + to reconnect. + + This is required to be ran before :meth:`StateManager.close`, + since this function doesn't wait for the tasks to complete. + """ + + self.accept_new = False + + #: store the shutdown tasks + tasks = [] + + for state in self.states_raw.values(): + if not state.ws: + continue + + tasks.append( + self.shutdown_single(state) + ) + + log.info('made {} shutdown tasks', len(tasks)) + + return tasks + + def close(self): + """Close the state manager.""" + self.closed = True diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 90bbc07..0d54ad3 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -753,6 +753,15 @@ class GatewayWebsocket: async def listen_messages(self): """Listen for messages coming in from the websocket.""" + + # close anyone trying to login while the + # server is shutting down + if self.ext.state_manager.closed: + raise WebsocketClose(4000, 'state manager closed') + + if not self.ext.state_manager.accept_new: + raise WebsocketClose(4000, 'state manager closed for new') + while True: message = await self.ws.recv() if len(message) > 4096: @@ -762,6 +771,9 @@ class GatewayWebsocket: await self.process_message(payload) def _cleanup(self): + for task in self.wsp.tasks.values(): + task.cancel() + if self.state: self.ext.state_manager.remove(self.state) self.state.ws = None diff --git a/run.py b/run.py index 27753dd..6f9f748 100644 --- a/run.py +++ b/run.py @@ -131,6 +131,8 @@ async def app_before_serving(): async def _wrapper(ws, url): # We wrap the main websocket_handler # so we can pass quart's app object. + + # TODO: pass just the app object await websocket_handler((app.db, app.state_manager, app.storage, app.loop, app.dispatcher, app.presence), ws, url) @@ -142,6 +144,15 @@ async def app_before_serving(): @app.after_serving async def app_after_serving(): + """Shutdown tasks for the server.""" + + # first close all clients, then close db + tasks = app.state_manager.gen_close_tasks() + if tasks: + await asyncio.wait(tasks, loop=app.loop) + + app.state_manager.close() + log.info('closing db') await app.db.close()