gateway.state_manager: send OP 7 Reconnect to clients

- gateway.websocket: check StateManager flags on new connections
 - gateway.websocket: cancel all tasks on GatewayWebsocket.wsp.tasks
 - run: call StateManager.gen_close_tasks() and StateManager.close() on
    app shutdown
This commit is contained in:
Luna Mendes 2018-11-03 21:58:51 -03:00
parent afb429ec77
commit 69fbd9c117
3 changed files with 127 additions and 3 deletions

View File

@ -1,18 +1,68 @@
import asyncio
from typing import List, Dict, Any from typing import List, Dict, Any
from collections import defaultdict from collections import defaultdict
from websockets.exceptions import ConnectionClosed
from logbook import Logger from logbook import Logger
from .state import GatewayState from litecord.gateway.state import GatewayState
from litecord.gateway.opcodes import OP
log = Logger(__name__) log = Logger(__name__)
class ManagerClose(Exception):
pass
class StateDictWrapper:
"""Wrap a mapping so that any kind of access to the mapping while the
state manager is closed raises a ManagerClose error"""
def __init__(self, state_manager, mapping):
self.state_manager = state_manager
self._map = mapping
def _check_closed(self):
if self.state_manager.closed:
raise ManagerClose()
def __getitem__(self, key):
self._check_closed()
return self._map[key]
def __delitem__(self, key):
self._check_closed()
del self._map[key]
def __setitem__(self, key, value):
if not self.state_manager.accept_new:
raise ManagerClose()
self._check_closed()
self._map[key] = value
def __iter__(self):
return self._map.__iter__()
def pop(self, key):
return self._map.pop(key)
def values(self):
return self._map.values()
class StateManager: class StateManager:
"""Manager for gateway state information.""" """Manager for gateway state information."""
def __init__(self): def __init__(self):
#: closed flag
self.closed = False
#: accept new states?
self.accept_new = True
# { # {
# user_id: { # user_id: {
# session_id: GatewayState, # session_id: GatewayState,
@ -20,10 +70,10 @@ class StateManager:
# }, # },
# user_id_2: {}, ... # user_id_2: {}, ...
# } # }
self.states = defaultdict(dict) self.states = StateDictWrapper(self, defaultdict(dict))
#: raw mapping from session ids to GatewayState #: raw mapping from session ids to GatewayState
self.states_raw = {} self.states_raw = StateDictWrapper(self, {})
def insert(self, state: GatewayState): def insert(self, state: GatewayState):
"""Insert a new state object.""" """Insert a new state object."""
@ -113,3 +163,54 @@ class StateManager:
states.extend(member_states) states.extend(member_states)
return states return states
async def shutdown_single(self, state: GatewayState):
"""Send OP Reconnect to a single connection."""
websocket = state.ws
await websocket.send({
'op': OP.RECONNECT
})
# wait 200ms
# so that the client has time to process
# our payload then close the connection
await asyncio.sleep(0.2)
try:
# try to close the connection ourselves
await websocket.ws.close(
code=4000,
reason='litecord shutting down'
)
except ConnectionClosed:
log.info('client {} already closed', state)
def gen_close_tasks(self):
"""Generate the tasks that will order the clients
to reconnect.
This is required to be ran before :meth:`StateManager.close`,
since this function doesn't wait for the tasks to complete.
"""
self.accept_new = False
#: store the shutdown tasks
tasks = []
for state in self.states_raw.values():
if not state.ws:
continue
tasks.append(
self.shutdown_single(state)
)
log.info('made {} shutdown tasks', len(tasks))
return tasks
def close(self):
"""Close the state manager."""
self.closed = True

View File

@ -753,6 +753,15 @@ class GatewayWebsocket:
async def listen_messages(self): async def listen_messages(self):
"""Listen for messages coming in from the websocket.""" """Listen for messages coming in from the websocket."""
# close anyone trying to login while the
# server is shutting down
if self.ext.state_manager.closed:
raise WebsocketClose(4000, 'state manager closed')
if not self.ext.state_manager.accept_new:
raise WebsocketClose(4000, 'state manager closed for new')
while True: while True:
message = await self.ws.recv() message = await self.ws.recv()
if len(message) > 4096: if len(message) > 4096:
@ -762,6 +771,9 @@ class GatewayWebsocket:
await self.process_message(payload) await self.process_message(payload)
def _cleanup(self): def _cleanup(self):
for task in self.wsp.tasks.values():
task.cancel()
if self.state: if self.state:
self.ext.state_manager.remove(self.state) self.ext.state_manager.remove(self.state)
self.state.ws = None self.state.ws = None

11
run.py
View File

@ -131,6 +131,8 @@ async def app_before_serving():
async def _wrapper(ws, url): async def _wrapper(ws, url):
# We wrap the main websocket_handler # We wrap the main websocket_handler
# so we can pass quart's app object. # so we can pass quart's app object.
# TODO: pass just the app object
await websocket_handler((app.db, app.state_manager, app.storage, await websocket_handler((app.db, app.state_manager, app.storage,
app.loop, app.dispatcher, app.presence), app.loop, app.dispatcher, app.presence),
ws, url) ws, url)
@ -142,6 +144,15 @@ async def app_before_serving():
@app.after_serving @app.after_serving
async def app_after_serving(): async def app_after_serving():
"""Shutdown tasks for the server."""
# first close all clients, then close db
tasks = app.state_manager.gen_close_tasks()
if tasks:
await asyncio.wait(tasks, loop=app.loop)
app.state_manager.close()
log.info('closing db') log.info('closing db')
await app.db.close() await app.db.close()