From 00c976c5526a97fc89de5cecec456c60f1e48a65 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sat, 17 Nov 2018 02:20:48 -0300 Subject: [PATCH] gateway.gateway: add default args when connecting - gateway.gateway: pass app instance instead of a 7-tuple --- litecord/gateway/gateway.py | 24 ++++++++++++++++-------- litecord/gateway/websocket.py | 12 +++++++++--- run.py | 7 +------ 3 files changed, 26 insertions(+), 17 deletions(-) diff --git a/litecord/gateway/gateway.py b/litecord/gateway/gateway.py index fe86226..9d5a354 100644 --- a/litecord/gateway/gateway.py +++ b/litecord/gateway/gateway.py @@ -2,31 +2,39 @@ import urllib.parse from .websocket import GatewayWebsocket -async def websocket_handler(prop, ws, url): - qs = urllib.parse.parse_qs( +async def websocket_handler(app, ws, url): + """Main websocket handler, checks query arguments + when connecting to the gateway and spawns a + GatewayWebsocket instance for the connection.""" + args = urllib.parse.parse_qs( urllib.parse.urlparse(url).query ) + # pull a dict.get but in a really bad way. try: - gw_version = qs['v'][0] - gw_encoding = qs['encoding'][0] + gw_version = args['v'][0] except (KeyError, IndexError): - return await ws.close(1000, 'Invalid query args') + gw_version = '6' - if gw_version not in ('6',): + try: + gw_encoding = args['encoding'][0] + except (KeyError, IndexError): + gw_encoding = 'json' + + if gw_version not in ('6', '7'): return await ws.close(1000, 'Invalid gateway version') if gw_encoding not in ('json', 'etf'): return await ws.close(1000, 'Invalid gateway encoding') try: - gw_compress = qs['compress'][0] + gw_compress = args['compress'][0] except (KeyError, IndexError): gw_compress = None if gw_compress and gw_compress not in ('zlib-stream',): return await ws.close(1000, 'Invalid gateway compress') - gws = GatewayWebsocket(ws, prop=prop, v=gw_version, + gws = GatewayWebsocket(ws, app, v=gw_version, encoding=gw_encoding, compress=gw_compress) await gws.run() diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 32531b1..6cc65a2 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -29,7 +29,8 @@ WebsocketProperties = collections.namedtuple( WebsocketObjects = collections.namedtuple( 'WebsocketObjects', ('db', 'state_manager', 'storage', - 'loop', 'dispatcher', 'presence', 'ratelimiter') + 'loop', 'dispatcher', 'presence', 'ratelimiter', + 'user_storage') ) @@ -82,8 +83,13 @@ def decode_etf(data: bytes): class GatewayWebsocket: """Main gateway websocket logic.""" - def __init__(self, ws, **kwargs): - self.ext = WebsocketObjects(*kwargs['prop']) + def __init__(self, ws, app, **kwargs): + self.ext = WebsocketObjects( + app.db, app.state_manager, app.storage, app.loop, + app.dispatcher, app.presence, app.ratelimiter, + app.user_storage + ) + self.storage = self.ext.storage self.presence = self.ext.presence self.ws = ws diff --git a/run.py b/run.py index f897cd3..07f1116 100644 --- a/run.py +++ b/run.py @@ -220,12 +220,7 @@ async def app_before_serving(): async def _wrapper(ws, url): # We wrap the main websocket_handler # so we can pass quart's app object. - - # TODO: pass just the app object - await websocket_handler((app.db, app.state_manager, app.storage, - app.loop, app.dispatcher, app.presence, - app.ratelimiter), - ws, url) + await websocket_handler(app, ws, url) ws_future = websockets.serve(_wrapper, host, port)