gateway.websocket: detach app object from GatewayWebsocket

It doesn't work since quart's objects only work with stuff
that is already from quart, e.g the current_app stuff
requires you to be inside a special hidden context
that only quart functions get.

Gateway code is detached from quart since quart's websocket
stuff can't handle custom error codes.

 - auth: optional db detach
 - gateway.errors: add InvalidShard, ShardingRequired
 - gateway.gateway: pass asyncpg connection and StateManager
 - gateway.state: add repr, etc
 - gateway.state_man: add remove(), fetch_states()
This commit is contained in:
Luna Mendes 2018-06-19 21:05:26 -03:00
parent 77c5a101c6
commit 6f0528eaec
7 changed files with 119 additions and 27 deletions

View File

@ -11,7 +11,8 @@ from .errors import AuthError
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
async def raw_token_check(token): async def raw_token_check(token, db=None):
db = db or app.db
user_id, _hmac = token.split('.') user_id, _hmac = token.split('.')
try: try:
@ -20,7 +21,7 @@ async def raw_token_check(token):
except (ValueError, binascii.Error): except (ValueError, binascii.Error):
raise AuthError('Invalid user ID type') raise AuthError('Invalid user ID type')
pwd_hash = await app.db.fetchval(""" pwd_hash = await db.fetchval("""
SELECT password_hash SELECT password_hash
FROM users FROM users
WHERE id = $1 WHERE id = $1

View File

@ -15,3 +15,17 @@ class DecodeError(WebsocketClose):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.args = [4002, self.args[0]] self.args = [4002, self.args[0]]
class InvalidShard(WebsocketClose):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.args = [4010, self.args[0]]
class ShardingRequired(WebsocketClose):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.args = [4011, self.args[0]]

View File

@ -2,7 +2,7 @@ import urllib.parse
from .websocket import GatewayWebsocket from .websocket import GatewayWebsocket
async def websocket_handler(app, ws, url): async def websocket_handler(db, sm, ws, url):
qs = urllib.parse.parse_qs( qs = urllib.parse.parse_qs(
urllib.parse.urlparse(url).query urllib.parse.urlparse(url).query
) )
@ -27,6 +27,6 @@ async def websocket_handler(app, ws, url):
if gw_compress and gw_compress not in ('zlib-stream',): if gw_compress and gw_compress not in ('zlib-stream',):
return await ws.close(1000, 'Invalid gateway compress') return await ws.close(1000, 'Invalid gateway compress')
gws = GatewayWebsocket(app, ws, v=gw_version, gws = GatewayWebsocket(sm, db, ws, v=gw_version,
encoding=gw_encoding, compress=gw_compress) encoding=gw_encoding, compress=gw_compress)
await gws.run() await gws.run()

View File

@ -13,4 +13,17 @@ class GatewayState:
Used to store all information tied to the websocket's session. Used to store all information tied to the websocket's session.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
pass self.session_id = kwargs.get('session_id', gen_session_id())
self.seq = kwargs.get('seq', 0)
self.shard = kwargs.get('shard', [0, 1])
self.user_id = kwargs.get('user_id')
self.ws = None
for key in kwargs:
value = kwargs[key]
self.__dict__[key] = value
def __repr__(self):
return (f'GatewayState<session={self.session_id} seq={self.seq} '
f'shard={self.shard} uid={self.user_id}>')

View File

@ -1,16 +1,49 @@
import logging
from typing import List
from collections import defaultdict
from .state import GatewayState from .state import GatewayState
log = logging.getLogger(__name__)
class StateManager: class StateManager:
"""Manager for gateway state information.""" """Manager for gateway state information."""
def __init__(self): def __init__(self):
self.states = {} self.states = defaultdict(dict)
def insert(self, state: GatewayState): def insert(self, state: GatewayState):
"""Insert a new state object.""" """Insert a new state object."""
user_states = self.states[state.user_id] user_states = self.states[state.user_id]
log.info(f'Inserting state {state!r}')
user_states[state.session_id] = state user_states[state.session_id] = state
def fetch(self, user_id: int, session_id: str) -> GatewayState: def fetch(self, user_id: int, session_id: str) -> GatewayState:
"""Fetch a state object from the registry.""" """Fetch a state object from the registry."""
return self.states[user_id][session_id] return self.states[user_id][session_id]
def remove(self, state):
"""Remove a state from the registry"""
if not state:
return
try:
log.info(f'Removing state {state!r}')
self.states[state.user_id].pop(state.session_id)
except KeyError:
pass
def fetch_states(self, user_id, guild_id) -> List[GatewayState]:
"""Fetch all states that are tied to a guild."""
states = []
for state in self.states[user_id]:
shard_id = (guild_id >> 22) % state.shard_count
if shard_id == state.current_shard:
states.append(state)
return states

View File

@ -6,7 +6,9 @@ import earl
from ..errors import WebsocketClose, AuthError from ..errors import WebsocketClose, AuthError
from ..auth import raw_token_check from ..auth import raw_token_check
from .errors import DecodeError, UnknownOPCode from .errors import DecodeError, UnknownOPCode, \
InvalidShard, ShardingRequired
from .opcodes import OP from .opcodes import OP
from .state import GatewayState, gen_session_id from .state import GatewayState, gen_session_id
@ -34,8 +36,9 @@ def decode_etf(data):
class GatewayWebsocket: class GatewayWebsocket:
"""Main gateway websocket logic.""" """Main gateway websocket logic."""
def __init__(self, app, ws, **kwargs): def __init__(self, sm, db, ws, **kwargs):
self.app = app self.state_manager = sm
self.db = db
self.ws = ws self.ws = ws
self.wsp = WebsocketProperties(kwargs.get('v'), self.wsp = WebsocketProperties(kwargs.get('v'),
@ -78,13 +81,47 @@ class GatewayWebsocket:
async def dispatch(self, event, data): async def dispatch(self, event, data):
"""Dispatch an event to the websocket.""" """Dispatch an event to the websocket."""
self.state.seq += 1
await self.send({ await self.send({
'op': OP.DISPATCH, 'op': OP.DISPATCH,
't': event.upper(), 't': event.upper(),
# 's': self.state.seq, 's': self.state.seq,
'd': data, 'd': data,
}) })
async def dispatch_ready(self):
await self.dispatch('READY', {
'v': 6,
'user': {'i': 'Boobs !! ! .........'},
'private_channels': [],
'guilds': [],
'session_id': self.state.session_id,
'_trace': ['despacito']
})
async def _check_shards(self):
shard = self.state.shard
current_shard, shard_count = shard
guilds = await self.db.fetchval("""
SELECT COUNT(*)
FROM members
WHERE user_id = $1
""", self.state.user_id)
recommended = max(int(guilds / 1200), 1)
if shard_count < recommended:
raise ShardingRequired('Too many guilds for shard '
f'{current_shard}')
if guilds / shard_count > 0.8:
raise ShardingRequired('Too many shards.')
if current_shard > shard_count:
raise InvaildShard('Shard count > Total shards')
async def handle_0(self, payload: dict): async def handle_0(self, payload: dict):
"""Handle the OP 0 Identify packet.""" """Handle the OP 0 Identify packet."""
data = payload['d'] data = payload['d']
@ -100,33 +137,25 @@ class GatewayWebsocket:
presence = data.get('presence') presence = data.get('presence')
try: try:
user_id = await raw_token_check(token) user_id = await raw_token_check(token, self.db)
except AuthError: except AuthError:
raise WebsocketClose(4004, 'Authentication failed') raise WebsocketClose(4004, 'Authentication failed')
session_id = gen_session_id()
self.state = GatewayState( self.state = GatewayState(
session_id=session_id,
user_id=user_id, user_id=user_id,
properties=properties, properties=properties,
compress=compress, compress=compress,
large=large, large=large,
shard=shard, shard=shard,
presence=presence presence=presence,
) )
self.app.state_manager.insert(self.state) self.state.ws = self
# TODO: dispatch READY await self._check_shards()
await self.dispatch('READY', {
'v': 6, self.state_manager.insert(self.state)
'user': {'i': 'Boobs !! ! .........'}, await self.dispatch_ready()
'private_channels': [],
'guilds': [],
'session_id': session_id,
'_trace': ['despacito']
})
async def process_message(self, payload): async def process_message(self, payload):
"""Process a single message coming in from the client.""" """Process a single message coming in from the client."""
@ -159,6 +188,8 @@ class GatewayWebsocket:
await self.send_hello() await self.send_hello()
await self.listen_messages() await self.listen_messages()
except WebsocketClose as err: except WebsocketClose as err:
log.warning(f'Closed a client, {self.state or "<none>"} {err!r}') log.warning(f'Closed a client, state={self.state or "<none>"} '
f'{err!r}')
await self.ws.close(code=err.code, await self.ws.close(code=err.code,
reason=err.reason) reason=err.reason)

2
run.py
View File

@ -49,7 +49,7 @@ async def app_before_serving():
async def _wrapper(ws, url): async def _wrapper(ws, url):
# We wrap the main websocket_handler # We wrap the main websocket_handler
# so we can pass quart's app object. # so we can pass quart's app object.
await websocket_handler(app, ws, url) await websocket_handler(app.db, app.state_manager, ws, url)
ws_future = websockets.serve(_wrapper, host, port) ws_future = websockets.serve(_wrapper, host, port)