gateway.gateway: add default args when connecting

- gateway.gateway: pass app instance instead of a 7-tuple
This commit is contained in:
Luna Mendes 2018-11-17 02:20:48 -03:00
parent 11d4b54f87
commit 00c976c552
3 changed files with 26 additions and 17 deletions

View File

@ -2,31 +2,39 @@ import urllib.parse
from .websocket import GatewayWebsocket from .websocket import GatewayWebsocket
async def websocket_handler(prop, ws, url): async def websocket_handler(app, ws, url):
qs = urllib.parse.parse_qs( """Main websocket handler, checks query arguments
when connecting to the gateway and spawns a
GatewayWebsocket instance for the connection."""
args = urllib.parse.parse_qs(
urllib.parse.urlparse(url).query urllib.parse.urlparse(url).query
) )
# pull a dict.get but in a really bad way.
try: try:
gw_version = qs['v'][0] gw_version = args['v'][0]
gw_encoding = qs['encoding'][0]
except (KeyError, IndexError): except (KeyError, IndexError):
return await ws.close(1000, 'Invalid query args') gw_version = '6'
if gw_version not in ('6',): try:
gw_encoding = args['encoding'][0]
except (KeyError, IndexError):
gw_encoding = 'json'
if gw_version not in ('6', '7'):
return await ws.close(1000, 'Invalid gateway version') return await ws.close(1000, 'Invalid gateway version')
if gw_encoding not in ('json', 'etf'): if gw_encoding not in ('json', 'etf'):
return await ws.close(1000, 'Invalid gateway encoding') return await ws.close(1000, 'Invalid gateway encoding')
try: try:
gw_compress = qs['compress'][0] gw_compress = args['compress'][0]
except (KeyError, IndexError): except (KeyError, IndexError):
gw_compress = None gw_compress = None
if gw_compress and gw_compress not in ('zlib-stream',): if gw_compress and gw_compress not in ('zlib-stream',):
return await ws.close(1000, 'Invalid gateway compress') return await ws.close(1000, 'Invalid gateway compress')
gws = GatewayWebsocket(ws, prop=prop, v=gw_version, gws = GatewayWebsocket(ws, app, v=gw_version,
encoding=gw_encoding, compress=gw_compress) encoding=gw_encoding, compress=gw_compress)
await gws.run() await gws.run()

View File

@ -29,7 +29,8 @@ WebsocketProperties = collections.namedtuple(
WebsocketObjects = collections.namedtuple( WebsocketObjects = collections.namedtuple(
'WebsocketObjects', ('db', 'state_manager', 'storage', 'WebsocketObjects', ('db', 'state_manager', 'storage',
'loop', 'dispatcher', 'presence', 'ratelimiter') 'loop', 'dispatcher', 'presence', 'ratelimiter',
'user_storage')
) )
@ -82,8 +83,13 @@ def decode_etf(data: bytes):
class GatewayWebsocket: class GatewayWebsocket:
"""Main gateway websocket logic.""" """Main gateway websocket logic."""
def __init__(self, ws, **kwargs): def __init__(self, ws, app, **kwargs):
self.ext = WebsocketObjects(*kwargs['prop']) self.ext = WebsocketObjects(
app.db, app.state_manager, app.storage, app.loop,
app.dispatcher, app.presence, app.ratelimiter,
app.user_storage
)
self.storage = self.ext.storage self.storage = self.ext.storage
self.presence = self.ext.presence self.presence = self.ext.presence
self.ws = ws self.ws = ws

7
run.py
View File

@ -220,12 +220,7 @@ async def app_before_serving():
async def _wrapper(ws, url): async def _wrapper(ws, url):
# We wrap the main websocket_handler # We wrap the main websocket_handler
# so we can pass quart's app object. # so we can pass quart's app object.
await websocket_handler(app, ws, url)
# TODO: pass just the app object
await websocket_handler((app.db, app.state_manager, app.storage,
app.loop, app.dispatcher, app.presence,
app.ratelimiter),
ws, url)
ws_future = websockets.serve(_wrapper, host, port) ws_future = websockets.serve(_wrapper, host, port)