mirror of https://gitlab.com/litecord/litecord.git
gateway.state_manager: send OP 7 Reconnect to clients
- gateway.websocket: check StateManager flags on new connections
- gateway.websocket: cancel all tasks on GatewayWebsocket.wsp.tasks
- run: call StateManager.gen_close_tasks() and StateManager.close() on
app shutdown
This commit is contained in:
parent
afb429ec77
commit
69fbd9c117
|
|
@ -1,18 +1,68 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from websockets.exceptions import ConnectionClosed
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
|
|
||||||
from .state import GatewayState
|
from litecord.gateway.state import GatewayState
|
||||||
|
from litecord.gateway.opcodes import OP
|
||||||
|
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ManagerClose(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StateDictWrapper:
|
||||||
|
"""Wrap a mapping so that any kind of access to the mapping while the
|
||||||
|
state manager is closed raises a ManagerClose error"""
|
||||||
|
def __init__(self, state_manager, mapping):
|
||||||
|
self.state_manager = state_manager
|
||||||
|
self._map = mapping
|
||||||
|
|
||||||
|
def _check_closed(self):
|
||||||
|
if self.state_manager.closed:
|
||||||
|
raise ManagerClose()
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
self._check_closed()
|
||||||
|
return self._map[key]
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
self._check_closed()
|
||||||
|
del self._map[key]
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
if not self.state_manager.accept_new:
|
||||||
|
raise ManagerClose()
|
||||||
|
|
||||||
|
self._check_closed()
|
||||||
|
self._map[key] = value
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self._map.__iter__()
|
||||||
|
|
||||||
|
def pop(self, key):
|
||||||
|
return self._map.pop(key)
|
||||||
|
|
||||||
|
def values(self):
|
||||||
|
return self._map.values()
|
||||||
|
|
||||||
|
|
||||||
class StateManager:
|
class StateManager:
|
||||||
"""Manager for gateway state information."""
|
"""Manager for gateway state information."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
#: closed flag
|
||||||
|
self.closed = False
|
||||||
|
|
||||||
|
#: accept new states?
|
||||||
|
self.accept_new = True
|
||||||
|
|
||||||
# {
|
# {
|
||||||
# user_id: {
|
# user_id: {
|
||||||
# session_id: GatewayState,
|
# session_id: GatewayState,
|
||||||
|
|
@ -20,10 +70,10 @@ class StateManager:
|
||||||
# },
|
# },
|
||||||
# user_id_2: {}, ...
|
# user_id_2: {}, ...
|
||||||
# }
|
# }
|
||||||
self.states = defaultdict(dict)
|
self.states = StateDictWrapper(self, defaultdict(dict))
|
||||||
|
|
||||||
#: raw mapping from session ids to GatewayState
|
#: raw mapping from session ids to GatewayState
|
||||||
self.states_raw = {}
|
self.states_raw = StateDictWrapper(self, {})
|
||||||
|
|
||||||
def insert(self, state: GatewayState):
|
def insert(self, state: GatewayState):
|
||||||
"""Insert a new state object."""
|
"""Insert a new state object."""
|
||||||
|
|
@ -113,3 +163,54 @@ class StateManager:
|
||||||
states.extend(member_states)
|
states.extend(member_states)
|
||||||
|
|
||||||
return states
|
return states
|
||||||
|
|
||||||
|
async def shutdown_single(self, state: GatewayState):
|
||||||
|
"""Send OP Reconnect to a single connection."""
|
||||||
|
websocket = state.ws
|
||||||
|
|
||||||
|
await websocket.send({
|
||||||
|
'op': OP.RECONNECT
|
||||||
|
})
|
||||||
|
|
||||||
|
# wait 200ms
|
||||||
|
# so that the client has time to process
|
||||||
|
# our payload then close the connection
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# try to close the connection ourselves
|
||||||
|
await websocket.ws.close(
|
||||||
|
code=4000,
|
||||||
|
reason='litecord shutting down'
|
||||||
|
)
|
||||||
|
except ConnectionClosed:
|
||||||
|
log.info('client {} already closed', state)
|
||||||
|
|
||||||
|
def gen_close_tasks(self):
|
||||||
|
"""Generate the tasks that will order the clients
|
||||||
|
to reconnect.
|
||||||
|
|
||||||
|
This is required to be ran before :meth:`StateManager.close`,
|
||||||
|
since this function doesn't wait for the tasks to complete.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.accept_new = False
|
||||||
|
|
||||||
|
#: store the shutdown tasks
|
||||||
|
tasks = []
|
||||||
|
|
||||||
|
for state in self.states_raw.values():
|
||||||
|
if not state.ws:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tasks.append(
|
||||||
|
self.shutdown_single(state)
|
||||||
|
)
|
||||||
|
|
||||||
|
log.info('made {} shutdown tasks', len(tasks))
|
||||||
|
|
||||||
|
return tasks
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close the state manager."""
|
||||||
|
self.closed = True
|
||||||
|
|
|
||||||
|
|
@ -753,6 +753,15 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
async def listen_messages(self):
|
async def listen_messages(self):
|
||||||
"""Listen for messages coming in from the websocket."""
|
"""Listen for messages coming in from the websocket."""
|
||||||
|
|
||||||
|
# close anyone trying to login while the
|
||||||
|
# server is shutting down
|
||||||
|
if self.ext.state_manager.closed:
|
||||||
|
raise WebsocketClose(4000, 'state manager closed')
|
||||||
|
|
||||||
|
if not self.ext.state_manager.accept_new:
|
||||||
|
raise WebsocketClose(4000, 'state manager closed for new')
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
message = await self.ws.recv()
|
message = await self.ws.recv()
|
||||||
if len(message) > 4096:
|
if len(message) > 4096:
|
||||||
|
|
@ -762,6 +771,9 @@ class GatewayWebsocket:
|
||||||
await self.process_message(payload)
|
await self.process_message(payload)
|
||||||
|
|
||||||
def _cleanup(self):
|
def _cleanup(self):
|
||||||
|
for task in self.wsp.tasks.values():
|
||||||
|
task.cancel()
|
||||||
|
|
||||||
if self.state:
|
if self.state:
|
||||||
self.ext.state_manager.remove(self.state)
|
self.ext.state_manager.remove(self.state)
|
||||||
self.state.ws = None
|
self.state.ws = None
|
||||||
|
|
|
||||||
11
run.py
11
run.py
|
|
@ -131,6 +131,8 @@ async def app_before_serving():
|
||||||
async def _wrapper(ws, url):
|
async def _wrapper(ws, url):
|
||||||
# We wrap the main websocket_handler
|
# We wrap the main websocket_handler
|
||||||
# so we can pass quart's app object.
|
# so we can pass quart's app object.
|
||||||
|
|
||||||
|
# TODO: pass just the app object
|
||||||
await websocket_handler((app.db, app.state_manager, app.storage,
|
await websocket_handler((app.db, app.state_manager, app.storage,
|
||||||
app.loop, app.dispatcher, app.presence),
|
app.loop, app.dispatcher, app.presence),
|
||||||
ws, url)
|
ws, url)
|
||||||
|
|
@ -142,6 +144,15 @@ async def app_before_serving():
|
||||||
|
|
||||||
@app.after_serving
|
@app.after_serving
|
||||||
async def app_after_serving():
|
async def app_after_serving():
|
||||||
|
"""Shutdown tasks for the server."""
|
||||||
|
|
||||||
|
# first close all clients, then close db
|
||||||
|
tasks = app.state_manager.gen_close_tasks()
|
||||||
|
if tasks:
|
||||||
|
await asyncio.wait(tasks, loop=app.loop)
|
||||||
|
|
||||||
|
app.state_manager.close()
|
||||||
|
|
||||||
log.info('closing db')
|
log.info('closing db')
|
||||||
await app.db.close()
|
await app.db.close()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue