mirror of https://gitlab.com/litecord/litecord.git
gateway.websocket: add basics of identify
- auth: add raw_token_check - gateway.gateway: pass the app object to GatewayWebsocket - gateway.state: add gen_session_id() - gateway: add state_man
This commit is contained in:
parent
32b9698ea7
commit
77c5a101c6
|
|
@ -1,5 +1,6 @@
|
|||
import base64
|
||||
import logging
|
||||
import binascii
|
||||
|
||||
from itsdangerous import Signer, BadSignature
|
||||
from quart import request, current_app as app
|
||||
|
|
@ -10,19 +11,13 @@ from .errors import AuthError
|
|||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def token_check():
|
||||
"""Check token information."""
|
||||
try:
|
||||
token = request.headers['Authorization']
|
||||
except KeyError:
|
||||
raise AuthError('No token provided')
|
||||
|
||||
async def raw_token_check(token):
|
||||
user_id, _hmac = token.split('.')
|
||||
|
||||
user_id = base64.b64decode(user_id.encode('utf-8'))
|
||||
try:
|
||||
user_id = base64.b64decode(user_id.encode('utf-8'))
|
||||
user_id = int(user_id)
|
||||
except ValueError:
|
||||
except (ValueError, binascii.Error):
|
||||
raise AuthError('Invalid user ID type')
|
||||
|
||||
pwd_hash = await app.db.fetchval("""
|
||||
|
|
@ -43,3 +38,13 @@ async def token_check():
|
|||
except BadSignature:
|
||||
log.warning('token fail for uid {user_id}')
|
||||
raise AuthError('Invalid token')
|
||||
|
||||
|
||||
async def token_check():
|
||||
"""Check token information."""
|
||||
try:
|
||||
token = request.headers['Authorization']
|
||||
except KeyError:
|
||||
raise AuthError('No token provided')
|
||||
|
||||
await raw_token_check(token)
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import urllib.parse
|
|||
from .websocket import GatewayWebsocket
|
||||
|
||||
|
||||
async def websocket_handler(ws, url):
|
||||
async def websocket_handler(app, ws, url):
|
||||
qs = urllib.parse.parse_qs(
|
||||
urllib.parse.urlparse(url).query
|
||||
)
|
||||
|
|
@ -27,6 +27,6 @@ async def websocket_handler(ws, url):
|
|||
if gw_compress and gw_compress not in ('zlib-stream',):
|
||||
return await ws.close(1000, 'Invalid gateway compress')
|
||||
|
||||
gws = GatewayWebsocket(ws, v=gw_version,
|
||||
gws = GatewayWebsocket(app, ws, v=gw_version,
|
||||
encoding=gw_encoding, compress=gw_compress)
|
||||
await gws.run()
|
||||
|
|
|
|||
|
|
@ -12,3 +12,4 @@ class OP:
|
|||
INVALID_SESSION = 9
|
||||
HELLO = 10
|
||||
HEARTBEAT_ACK = 11
|
||||
GUILD_SYNC = 12
|
||||
|
|
|
|||
|
|
@ -1,3 +1,12 @@
|
|||
import hashlib
|
||||
import os
|
||||
|
||||
|
||||
def gen_session_id() -> str:
|
||||
"""Generate a random session ID."""
|
||||
return hashlib.sha1(os.urandom(256)).hexdigest()
|
||||
|
||||
|
||||
class GatewayState:
|
||||
"""Main websocket state.
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
from .state import GatewayState
|
||||
|
||||
|
||||
class StateManager:
|
||||
"""Manager for gateway state information."""
|
||||
def __init__(self):
|
||||
self.states = {}
|
||||
|
||||
def insert(self, state: GatewayState):
|
||||
"""Insert a new state object."""
|
||||
user_states = self.states[state.user_id]
|
||||
user_states[state.session_id] = state
|
||||
|
||||
def fetch(self, user_id: int, session_id: str) -> GatewayState:
|
||||
"""Fetch a state object from the registry."""
|
||||
return self.states[user_id][session_id]
|
||||
|
|
@ -1,10 +1,19 @@
|
|||
import json
|
||||
import logging
|
||||
import collections
|
||||
|
||||
import earl
|
||||
|
||||
from ..errors import WebsocketClose
|
||||
from ..errors import WebsocketClose, AuthError
|
||||
from ..auth import raw_token_check
|
||||
from .errors import DecodeError, UnknownOPCode
|
||||
from .opcodes import OP
|
||||
from .state import GatewayState, gen_session_id
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
WebsocketProperties = collections.namedtuple(
|
||||
'WebsocketProperties', 'v encoding compress')
|
||||
|
||||
|
||||
def encode_json(payload) -> str:
|
||||
|
|
@ -25,16 +34,20 @@ def decode_etf(data):
|
|||
|
||||
class GatewayWebsocket:
|
||||
"""Main gateway websocket logic."""
|
||||
def __init__(self, ws, **kwargs):
|
||||
def __init__(self, app, ws, **kwargs):
|
||||
self.app = app
|
||||
self.ws = ws
|
||||
self.version = kwargs.get('v', 6)
|
||||
self.encoding = kwargs.get('encoding', 'json')
|
||||
self.compress = kwargs.get('compress', None)
|
||||
|
||||
self.set_encoders()
|
||||
self.wsp = WebsocketProperties(kwargs.get('v'),
|
||||
kwargs.get('encoding', 'json'),
|
||||
kwargs.get('compress', None))
|
||||
|
||||
def set_encoders(self):
|
||||
encoding = self.encoding
|
||||
self.state = None
|
||||
|
||||
self._set_encoders()
|
||||
|
||||
def _set_encoders(self):
|
||||
encoding = self.wsp.encoding
|
||||
|
||||
encodings = {
|
||||
'json': (encode_json, decode_json),
|
||||
|
|
@ -43,7 +56,8 @@ class GatewayWebsocket:
|
|||
|
||||
self.encoder, self.decoder = encodings[encoding]
|
||||
|
||||
async def send(self, payload):
|
||||
async def send(self, payload: dict):
|
||||
"""Send a payload to the websocket"""
|
||||
encoded = self.encoder(payload)
|
||||
|
||||
# TODO: compression
|
||||
|
|
@ -51,7 +65,7 @@ class GatewayWebsocket:
|
|||
await self.ws.send(encoded)
|
||||
|
||||
async def send_hello(self):
|
||||
"""Send the OP 10 Hello"""
|
||||
"""Send the OP 10 Hello packet over the websocket."""
|
||||
await self.send({
|
||||
'op': OP.HELLO,
|
||||
'd': {
|
||||
|
|
@ -62,8 +76,57 @@ class GatewayWebsocket:
|
|||
}
|
||||
})
|
||||
|
||||
async def handle_0(self, payload):
|
||||
pass
|
||||
async def dispatch(self, event, data):
|
||||
"""Dispatch an event to the websocket."""
|
||||
await self.send({
|
||||
'op': OP.DISPATCH,
|
||||
't': event.upper(),
|
||||
# 's': self.state.seq,
|
||||
'd': data,
|
||||
})
|
||||
|
||||
async def handle_0(self, payload: dict):
|
||||
"""Handle the OP 0 Identify packet."""
|
||||
data = payload['d']
|
||||
try:
|
||||
token, properties = data['token'], data['properties']
|
||||
except KeyError:
|
||||
raise DecodeError('Invalid identify parameters')
|
||||
|
||||
compress = data.get('compress', False)
|
||||
large = data.get('large_threshold', 50)
|
||||
|
||||
shard = data.get('shard', [0, 1])
|
||||
presence = data.get('presence')
|
||||
|
||||
try:
|
||||
user_id = await raw_token_check(token)
|
||||
except AuthError:
|
||||
raise WebsocketClose(4004, 'Authentication failed')
|
||||
|
||||
session_id = gen_session_id()
|
||||
|
||||
self.state = GatewayState(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
properties=properties,
|
||||
compress=compress,
|
||||
large=large,
|
||||
shard=shard,
|
||||
presence=presence
|
||||
)
|
||||
|
||||
self.app.state_manager.insert(self.state)
|
||||
|
||||
# TODO: dispatch READY
|
||||
await self.dispatch('READY', {
|
||||
'v': 6,
|
||||
'user': {'i': 'Boobs !! ! .........'},
|
||||
'private_channels': [],
|
||||
'guilds': [],
|
||||
'session_id': session_id,
|
||||
'_trace': ['despacito']
|
||||
})
|
||||
|
||||
async def process_message(self, payload):
|
||||
"""Process a single message coming in from the client."""
|
||||
|
|
@ -96,5 +159,6 @@ class GatewayWebsocket:
|
|||
await self.send_hello()
|
||||
await self.listen_messages()
|
||||
except WebsocketClose as err:
|
||||
log.warning(f'Closed a client, {self.state or "<none>"} {err!r}')
|
||||
await self.ws.close(code=err.code,
|
||||
reason=err.reason)
|
||||
|
|
|
|||
12
run.py
12
run.py
|
|
@ -10,6 +10,7 @@ import config
|
|||
from litecord.blueprints import gateway, auth
|
||||
from litecord.gateway import websocket_handler
|
||||
from litecord.errors import LitecordError
|
||||
from litecord.gateway.state_man import StateManager
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
log = logging.getLogger(__name__)
|
||||
|
|
@ -39,11 +40,18 @@ async def app_before_serving():
|
|||
app.loop = asyncio.get_event_loop()
|
||||
g.loop = asyncio.get_event_loop()
|
||||
|
||||
app.state_manager = StateManager()
|
||||
|
||||
# start the websocket, etc
|
||||
host, port = app.config['WS_HOST'], app.config['WS_PORT']
|
||||
log.info(f'starting websocket at {host} {port}')
|
||||
ws_future = websockets.serve(
|
||||
websocket_handler, host, port)
|
||||
|
||||
async def _wrapper(ws, url):
|
||||
# We wrap the main websocket_handler
|
||||
# so we can pass quart's app object.
|
||||
await websocket_handler(app, ws, url)
|
||||
|
||||
ws_future = websockets.serve(_wrapper, host, port)
|
||||
|
||||
await ws_future
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue