gateway.websocket: add basics of identify

- auth: add raw_token_check
 - gateway.gateway: pass the app object to GatewayWebsocket
 - gateway.state: add gen_session_id()
 - gateway: add state_man
This commit is contained in:
Luna Mendes 2018-06-19 19:09:12 -03:00
parent 32b9698ea7
commit 77c5a101c6
7 changed files with 128 additions and 25 deletions

View File

@ -1,5 +1,6 @@
import base64
import logging
import binascii
from itsdangerous import Signer, BadSignature
from quart import request, current_app as app
@ -10,19 +11,13 @@ from .errors import AuthError
log = logging.getLogger(__name__)
async def token_check():
"""Check token information."""
try:
token = request.headers['Authorization']
except KeyError:
raise AuthError('No token provided')
async def raw_token_check(token):
user_id, _hmac = token.split('.')
user_id = base64.b64decode(user_id.encode('utf-8'))
try:
user_id = base64.b64decode(user_id.encode('utf-8'))
user_id = int(user_id)
except ValueError:
except (ValueError, binascii.Error):
raise AuthError('Invalid user ID type')
pwd_hash = await app.db.fetchval("""
@ -43,3 +38,13 @@ async def token_check():
except BadSignature:
log.warning('token fail for uid {user_id}')
raise AuthError('Invalid token')
async def token_check():
"""Check token information."""
try:
token = request.headers['Authorization']
except KeyError:
raise AuthError('No token provided')
await raw_token_check(token)

View File

@ -2,7 +2,7 @@ import urllib.parse
from .websocket import GatewayWebsocket
async def websocket_handler(ws, url):
async def websocket_handler(app, ws, url):
qs = urllib.parse.parse_qs(
urllib.parse.urlparse(url).query
)
@ -27,6 +27,6 @@ async def websocket_handler(ws, url):
if gw_compress and gw_compress not in ('zlib-stream',):
return await ws.close(1000, 'Invalid gateway compress')
gws = GatewayWebsocket(ws, v=gw_version,
gws = GatewayWebsocket(app, ws, v=gw_version,
encoding=gw_encoding, compress=gw_compress)
await gws.run()

View File

@ -12,3 +12,4 @@ class OP:
INVALID_SESSION = 9
HELLO = 10
HEARTBEAT_ACK = 11
GUILD_SYNC = 12

View File

@ -1,3 +1,12 @@
import hashlib
import os
def gen_session_id() -> str:
"""Generate a random session ID."""
return hashlib.sha1(os.urandom(256)).hexdigest()
class GatewayState:
"""Main websocket state.

View File

@ -0,0 +1,16 @@
from .state import GatewayState
class StateManager:
"""Manager for gateway state information."""
def __init__(self):
self.states = {}
def insert(self, state: GatewayState):
"""Insert a new state object."""
user_states = self.states[state.user_id]
user_states[state.session_id] = state
def fetch(self, user_id: int, session_id: str) -> GatewayState:
"""Fetch a state object from the registry."""
return self.states[user_id][session_id]

View File

@ -1,10 +1,19 @@
import json
import logging
import collections
import earl
from ..errors import WebsocketClose
from ..errors import WebsocketClose, AuthError
from ..auth import raw_token_check
from .errors import DecodeError, UnknownOPCode
from .opcodes import OP
from .state import GatewayState, gen_session_id
log = logging.getLogger(__name__)
WebsocketProperties = collections.namedtuple(
'WebsocketProperties', 'v encoding compress')
def encode_json(payload) -> str:
@ -25,16 +34,20 @@ def decode_etf(data):
class GatewayWebsocket:
"""Main gateway websocket logic."""
def __init__(self, ws, **kwargs):
def __init__(self, app, ws, **kwargs):
self.app = app
self.ws = ws
self.version = kwargs.get('v', 6)
self.encoding = kwargs.get('encoding', 'json')
self.compress = kwargs.get('compress', None)
self.set_encoders()
self.wsp = WebsocketProperties(kwargs.get('v'),
kwargs.get('encoding', 'json'),
kwargs.get('compress', None))
def set_encoders(self):
encoding = self.encoding
self.state = None
self._set_encoders()
def _set_encoders(self):
encoding = self.wsp.encoding
encodings = {
'json': (encode_json, decode_json),
@ -43,7 +56,8 @@ class GatewayWebsocket:
self.encoder, self.decoder = encodings[encoding]
async def send(self, payload):
async def send(self, payload: dict):
"""Send a payload to the websocket"""
encoded = self.encoder(payload)
# TODO: compression
@ -51,7 +65,7 @@ class GatewayWebsocket:
await self.ws.send(encoded)
async def send_hello(self):
"""Send the OP 10 Hello"""
"""Send the OP 10 Hello packet over the websocket."""
await self.send({
'op': OP.HELLO,
'd': {
@ -62,8 +76,57 @@ class GatewayWebsocket:
}
})
async def handle_0(self, payload):
pass
async def dispatch(self, event, data):
"""Dispatch an event to the websocket."""
await self.send({
'op': OP.DISPATCH,
't': event.upper(),
# 's': self.state.seq,
'd': data,
})
async def handle_0(self, payload: dict):
"""Handle the OP 0 Identify packet."""
data = payload['d']
try:
token, properties = data['token'], data['properties']
except KeyError:
raise DecodeError('Invalid identify parameters')
compress = data.get('compress', False)
large = data.get('large_threshold', 50)
shard = data.get('shard', [0, 1])
presence = data.get('presence')
try:
user_id = await raw_token_check(token)
except AuthError:
raise WebsocketClose(4004, 'Authentication failed')
session_id = gen_session_id()
self.state = GatewayState(
session_id=session_id,
user_id=user_id,
properties=properties,
compress=compress,
large=large,
shard=shard,
presence=presence
)
self.app.state_manager.insert(self.state)
# TODO: dispatch READY
await self.dispatch('READY', {
'v': 6,
'user': {'i': 'Boobs !! ! .........'},
'private_channels': [],
'guilds': [],
'session_id': session_id,
'_trace': ['despacito']
})
async def process_message(self, payload):
"""Process a single message coming in from the client."""
@ -96,5 +159,6 @@ class GatewayWebsocket:
await self.send_hello()
await self.listen_messages()
except WebsocketClose as err:
log.warning(f'Closed a client, {self.state or "<none>"} {err!r}')
await self.ws.close(code=err.code,
reason=err.reason)

12
run.py
View File

@ -10,6 +10,7 @@ import config
from litecord.blueprints import gateway, auth
from litecord.gateway import websocket_handler
from litecord.errors import LitecordError
from litecord.gateway.state_man import StateManager
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)
@ -39,11 +40,18 @@ async def app_before_serving():
app.loop = asyncio.get_event_loop()
g.loop = asyncio.get_event_loop()
app.state_manager = StateManager()
# start the websocket, etc
host, port = app.config['WS_HOST'], app.config['WS_PORT']
log.info(f'starting websocket at {host} {port}')
ws_future = websockets.serve(
websocket_handler, host, port)
async def _wrapper(ws, url):
# We wrap the main websocket_handler
# so we can pass quart's app object.
await websocket_handler(app, ws, url)
ws_future = websockets.serve(_wrapper, host, port)
await ws_future