mirror of https://gitlab.com/litecord/litecord.git
gateway.gateway: add default args when connecting
- gateway.gateway: pass app instance instead of a 7-tuple
This commit is contained in:
parent
11d4b54f87
commit
00c976c552
|
|
@ -2,31 +2,39 @@ import urllib.parse
|
||||||
from .websocket import GatewayWebsocket
|
from .websocket import GatewayWebsocket
|
||||||
|
|
||||||
|
|
||||||
async def websocket_handler(prop, ws, url):
|
async def websocket_handler(app, ws, url):
|
||||||
qs = urllib.parse.parse_qs(
|
"""Main websocket handler, checks query arguments
|
||||||
|
when connecting to the gateway and spawns a
|
||||||
|
GatewayWebsocket instance for the connection."""
|
||||||
|
args = urllib.parse.parse_qs(
|
||||||
urllib.parse.urlparse(url).query
|
urllib.parse.urlparse(url).query
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pull a dict.get but in a really bad way.
|
||||||
try:
|
try:
|
||||||
gw_version = qs['v'][0]
|
gw_version = args['v'][0]
|
||||||
gw_encoding = qs['encoding'][0]
|
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError):
|
||||||
return await ws.close(1000, 'Invalid query args')
|
gw_version = '6'
|
||||||
|
|
||||||
if gw_version not in ('6',):
|
try:
|
||||||
|
gw_encoding = args['encoding'][0]
|
||||||
|
except (KeyError, IndexError):
|
||||||
|
gw_encoding = 'json'
|
||||||
|
|
||||||
|
if gw_version not in ('6', '7'):
|
||||||
return await ws.close(1000, 'Invalid gateway version')
|
return await ws.close(1000, 'Invalid gateway version')
|
||||||
|
|
||||||
if gw_encoding not in ('json', 'etf'):
|
if gw_encoding not in ('json', 'etf'):
|
||||||
return await ws.close(1000, 'Invalid gateway encoding')
|
return await ws.close(1000, 'Invalid gateway encoding')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
gw_compress = qs['compress'][0]
|
gw_compress = args['compress'][0]
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError):
|
||||||
gw_compress = None
|
gw_compress = None
|
||||||
|
|
||||||
if gw_compress and gw_compress not in ('zlib-stream',):
|
if gw_compress and gw_compress not in ('zlib-stream',):
|
||||||
return await ws.close(1000, 'Invalid gateway compress')
|
return await ws.close(1000, 'Invalid gateway compress')
|
||||||
|
|
||||||
gws = GatewayWebsocket(ws, prop=prop, v=gw_version,
|
gws = GatewayWebsocket(ws, app, v=gw_version,
|
||||||
encoding=gw_encoding, compress=gw_compress)
|
encoding=gw_encoding, compress=gw_compress)
|
||||||
await gws.run()
|
await gws.run()
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,8 @@ WebsocketProperties = collections.namedtuple(
|
||||||
|
|
||||||
WebsocketObjects = collections.namedtuple(
|
WebsocketObjects = collections.namedtuple(
|
||||||
'WebsocketObjects', ('db', 'state_manager', 'storage',
|
'WebsocketObjects', ('db', 'state_manager', 'storage',
|
||||||
'loop', 'dispatcher', 'presence', 'ratelimiter')
|
'loop', 'dispatcher', 'presence', 'ratelimiter',
|
||||||
|
'user_storage')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -82,8 +83,13 @@ def decode_etf(data: bytes):
|
||||||
class GatewayWebsocket:
|
class GatewayWebsocket:
|
||||||
"""Main gateway websocket logic."""
|
"""Main gateway websocket logic."""
|
||||||
|
|
||||||
def __init__(self, ws, **kwargs):
|
def __init__(self, ws, app, **kwargs):
|
||||||
self.ext = WebsocketObjects(*kwargs['prop'])
|
self.ext = WebsocketObjects(
|
||||||
|
app.db, app.state_manager, app.storage, app.loop,
|
||||||
|
app.dispatcher, app.presence, app.ratelimiter,
|
||||||
|
app.user_storage
|
||||||
|
)
|
||||||
|
|
||||||
self.storage = self.ext.storage
|
self.storage = self.ext.storage
|
||||||
self.presence = self.ext.presence
|
self.presence = self.ext.presence
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
|
|
|
||||||
7
run.py
7
run.py
|
|
@ -220,12 +220,7 @@ async def app_before_serving():
|
||||||
async def _wrapper(ws, url):
|
async def _wrapper(ws, url):
|
||||||
# We wrap the main websocket_handler
|
# We wrap the main websocket_handler
|
||||||
# so we can pass quart's app object.
|
# so we can pass quart's app object.
|
||||||
|
await websocket_handler(app, ws, url)
|
||||||
# TODO: pass just the app object
|
|
||||||
await websocket_handler((app.db, app.state_manager, app.storage,
|
|
||||||
app.loop, app.dispatcher, app.presence,
|
|
||||||
app.ratelimiter),
|
|
||||||
ws, url)
|
|
||||||
|
|
||||||
ws_future = websockets.serve(_wrapper, host, port)
|
ws_future = websockets.serve(_wrapper, host, port)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue