mirror of https://gitlab.com/litecord/litecord.git
gateway.websocket: add basics of identify
- auth: add raw_token_check - gateway.gateway: pass the app object to GatewayWebsocket - gateway.state: add gen_session_id() - gateway: add state_man
This commit is contained in:
parent
32b9698ea7
commit
77c5a101c6
|
|
@ -1,5 +1,6 @@
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
|
import binascii
|
||||||
|
|
||||||
from itsdangerous import Signer, BadSignature
|
from itsdangerous import Signer, BadSignature
|
||||||
from quart import request, current_app as app
|
from quart import request, current_app as app
|
||||||
|
|
@ -10,19 +11,13 @@ from .errors import AuthError
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def token_check():
|
async def raw_token_check(token):
|
||||||
"""Check token information."""
|
|
||||||
try:
|
|
||||||
token = request.headers['Authorization']
|
|
||||||
except KeyError:
|
|
||||||
raise AuthError('No token provided')
|
|
||||||
|
|
||||||
user_id, _hmac = token.split('.')
|
user_id, _hmac = token.split('.')
|
||||||
|
|
||||||
user_id = base64.b64decode(user_id.encode('utf-8'))
|
|
||||||
try:
|
try:
|
||||||
|
user_id = base64.b64decode(user_id.encode('utf-8'))
|
||||||
user_id = int(user_id)
|
user_id = int(user_id)
|
||||||
except ValueError:
|
except (ValueError, binascii.Error):
|
||||||
raise AuthError('Invalid user ID type')
|
raise AuthError('Invalid user ID type')
|
||||||
|
|
||||||
pwd_hash = await app.db.fetchval("""
|
pwd_hash = await app.db.fetchval("""
|
||||||
|
|
@ -43,3 +38,13 @@ async def token_check():
|
||||||
except BadSignature:
|
except BadSignature:
|
||||||
log.warning('token fail for uid {user_id}')
|
log.warning('token fail for uid {user_id}')
|
||||||
raise AuthError('Invalid token')
|
raise AuthError('Invalid token')
|
||||||
|
|
||||||
|
|
||||||
|
async def token_check():
|
||||||
|
"""Check token information."""
|
||||||
|
try:
|
||||||
|
token = request.headers['Authorization']
|
||||||
|
except KeyError:
|
||||||
|
raise AuthError('No token provided')
|
||||||
|
|
||||||
|
await raw_token_check(token)
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import urllib.parse
|
||||||
from .websocket import GatewayWebsocket
|
from .websocket import GatewayWebsocket
|
||||||
|
|
||||||
|
|
||||||
async def websocket_handler(ws, url):
|
async def websocket_handler(app, ws, url):
|
||||||
qs = urllib.parse.parse_qs(
|
qs = urllib.parse.parse_qs(
|
||||||
urllib.parse.urlparse(url).query
|
urllib.parse.urlparse(url).query
|
||||||
)
|
)
|
||||||
|
|
@ -27,6 +27,6 @@ async def websocket_handler(ws, url):
|
||||||
if gw_compress and gw_compress not in ('zlib-stream',):
|
if gw_compress and gw_compress not in ('zlib-stream',):
|
||||||
return await ws.close(1000, 'Invalid gateway compress')
|
return await ws.close(1000, 'Invalid gateway compress')
|
||||||
|
|
||||||
gws = GatewayWebsocket(ws, v=gw_version,
|
gws = GatewayWebsocket(app, ws, v=gw_version,
|
||||||
encoding=gw_encoding, compress=gw_compress)
|
encoding=gw_encoding, compress=gw_compress)
|
||||||
await gws.run()
|
await gws.run()
|
||||||
|
|
|
||||||
|
|
@ -12,3 +12,4 @@ class OP:
|
||||||
INVALID_SESSION = 9
|
INVALID_SESSION = 9
|
||||||
HELLO = 10
|
HELLO = 10
|
||||||
HEARTBEAT_ACK = 11
|
HEARTBEAT_ACK = 11
|
||||||
|
GUILD_SYNC = 12
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,12 @@
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def gen_session_id() -> str:
|
||||||
|
"""Generate a random session ID."""
|
||||||
|
return hashlib.sha1(os.urandom(256)).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
class GatewayState:
|
class GatewayState:
|
||||||
"""Main websocket state.
|
"""Main websocket state.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
from .state import GatewayState
|
||||||
|
|
||||||
|
|
||||||
|
class StateManager:
|
||||||
|
"""Manager for gateway state information."""
|
||||||
|
def __init__(self):
|
||||||
|
self.states = {}
|
||||||
|
|
||||||
|
def insert(self, state: GatewayState):
|
||||||
|
"""Insert a new state object."""
|
||||||
|
user_states = self.states[state.user_id]
|
||||||
|
user_states[state.session_id] = state
|
||||||
|
|
||||||
|
def fetch(self, user_id: int, session_id: str) -> GatewayState:
|
||||||
|
"""Fetch a state object from the registry."""
|
||||||
|
return self.states[user_id][session_id]
|
||||||
|
|
@ -1,10 +1,19 @@
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
|
import collections
|
||||||
|
|
||||||
import earl
|
import earl
|
||||||
|
|
||||||
from ..errors import WebsocketClose
|
from ..errors import WebsocketClose, AuthError
|
||||||
|
from ..auth import raw_token_check
|
||||||
from .errors import DecodeError, UnknownOPCode
|
from .errors import DecodeError, UnknownOPCode
|
||||||
from .opcodes import OP
|
from .opcodes import OP
|
||||||
|
from .state import GatewayState, gen_session_id
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
WebsocketProperties = collections.namedtuple(
|
||||||
|
'WebsocketProperties', 'v encoding compress')
|
||||||
|
|
||||||
|
|
||||||
def encode_json(payload) -> str:
|
def encode_json(payload) -> str:
|
||||||
|
|
@ -25,16 +34,20 @@ def decode_etf(data):
|
||||||
|
|
||||||
class GatewayWebsocket:
|
class GatewayWebsocket:
|
||||||
"""Main gateway websocket logic."""
|
"""Main gateway websocket logic."""
|
||||||
def __init__(self, ws, **kwargs):
|
def __init__(self, app, ws, **kwargs):
|
||||||
|
self.app = app
|
||||||
self.ws = ws
|
self.ws = ws
|
||||||
self.version = kwargs.get('v', 6)
|
|
||||||
self.encoding = kwargs.get('encoding', 'json')
|
|
||||||
self.compress = kwargs.get('compress', None)
|
|
||||||
|
|
||||||
self.set_encoders()
|
self.wsp = WebsocketProperties(kwargs.get('v'),
|
||||||
|
kwargs.get('encoding', 'json'),
|
||||||
|
kwargs.get('compress', None))
|
||||||
|
|
||||||
def set_encoders(self):
|
self.state = None
|
||||||
encoding = self.encoding
|
|
||||||
|
self._set_encoders()
|
||||||
|
|
||||||
|
def _set_encoders(self):
|
||||||
|
encoding = self.wsp.encoding
|
||||||
|
|
||||||
encodings = {
|
encodings = {
|
||||||
'json': (encode_json, decode_json),
|
'json': (encode_json, decode_json),
|
||||||
|
|
@ -43,7 +56,8 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
self.encoder, self.decoder = encodings[encoding]
|
self.encoder, self.decoder = encodings[encoding]
|
||||||
|
|
||||||
async def send(self, payload):
|
async def send(self, payload: dict):
|
||||||
|
"""Send a payload to the websocket"""
|
||||||
encoded = self.encoder(payload)
|
encoded = self.encoder(payload)
|
||||||
|
|
||||||
# TODO: compression
|
# TODO: compression
|
||||||
|
|
@ -51,7 +65,7 @@ class GatewayWebsocket:
|
||||||
await self.ws.send(encoded)
|
await self.ws.send(encoded)
|
||||||
|
|
||||||
async def send_hello(self):
|
async def send_hello(self):
|
||||||
"""Send the OP 10 Hello"""
|
"""Send the OP 10 Hello packet over the websocket."""
|
||||||
await self.send({
|
await self.send({
|
||||||
'op': OP.HELLO,
|
'op': OP.HELLO,
|
||||||
'd': {
|
'd': {
|
||||||
|
|
@ -62,8 +76,57 @@ class GatewayWebsocket:
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
async def handle_0(self, payload):
|
async def dispatch(self, event, data):
|
||||||
pass
|
"""Dispatch an event to the websocket."""
|
||||||
|
await self.send({
|
||||||
|
'op': OP.DISPATCH,
|
||||||
|
't': event.upper(),
|
||||||
|
# 's': self.state.seq,
|
||||||
|
'd': data,
|
||||||
|
})
|
||||||
|
|
||||||
|
async def handle_0(self, payload: dict):
|
||||||
|
"""Handle the OP 0 Identify packet."""
|
||||||
|
data = payload['d']
|
||||||
|
try:
|
||||||
|
token, properties = data['token'], data['properties']
|
||||||
|
except KeyError:
|
||||||
|
raise DecodeError('Invalid identify parameters')
|
||||||
|
|
||||||
|
compress = data.get('compress', False)
|
||||||
|
large = data.get('large_threshold', 50)
|
||||||
|
|
||||||
|
shard = data.get('shard', [0, 1])
|
||||||
|
presence = data.get('presence')
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_id = await raw_token_check(token)
|
||||||
|
except AuthError:
|
||||||
|
raise WebsocketClose(4004, 'Authentication failed')
|
||||||
|
|
||||||
|
session_id = gen_session_id()
|
||||||
|
|
||||||
|
self.state = GatewayState(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
properties=properties,
|
||||||
|
compress=compress,
|
||||||
|
large=large,
|
||||||
|
shard=shard,
|
||||||
|
presence=presence
|
||||||
|
)
|
||||||
|
|
||||||
|
self.app.state_manager.insert(self.state)
|
||||||
|
|
||||||
|
# TODO: dispatch READY
|
||||||
|
await self.dispatch('READY', {
|
||||||
|
'v': 6,
|
||||||
|
'user': {'i': 'Boobs !! ! .........'},
|
||||||
|
'private_channels': [],
|
||||||
|
'guilds': [],
|
||||||
|
'session_id': session_id,
|
||||||
|
'_trace': ['despacito']
|
||||||
|
})
|
||||||
|
|
||||||
async def process_message(self, payload):
|
async def process_message(self, payload):
|
||||||
"""Process a single message coming in from the client."""
|
"""Process a single message coming in from the client."""
|
||||||
|
|
@ -96,5 +159,6 @@ class GatewayWebsocket:
|
||||||
await self.send_hello()
|
await self.send_hello()
|
||||||
await self.listen_messages()
|
await self.listen_messages()
|
||||||
except WebsocketClose as err:
|
except WebsocketClose as err:
|
||||||
|
log.warning(f'Closed a client, {self.state or "<none>"} {err!r}')
|
||||||
await self.ws.close(code=err.code,
|
await self.ws.close(code=err.code,
|
||||||
reason=err.reason)
|
reason=err.reason)
|
||||||
|
|
|
||||||
12
run.py
12
run.py
|
|
@ -10,6 +10,7 @@ import config
|
||||||
from litecord.blueprints import gateway, auth
|
from litecord.blueprints import gateway, auth
|
||||||
from litecord.gateway import websocket_handler
|
from litecord.gateway import websocket_handler
|
||||||
from litecord.errors import LitecordError
|
from litecord.errors import LitecordError
|
||||||
|
from litecord.gateway.state_man import StateManager
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
@ -39,11 +40,18 @@ async def app_before_serving():
|
||||||
app.loop = asyncio.get_event_loop()
|
app.loop = asyncio.get_event_loop()
|
||||||
g.loop = asyncio.get_event_loop()
|
g.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
app.state_manager = StateManager()
|
||||||
|
|
||||||
# start the websocket, etc
|
# start the websocket, etc
|
||||||
host, port = app.config['WS_HOST'], app.config['WS_PORT']
|
host, port = app.config['WS_HOST'], app.config['WS_PORT']
|
||||||
log.info(f'starting websocket at {host} {port}')
|
log.info(f'starting websocket at {host} {port}')
|
||||||
ws_future = websockets.serve(
|
|
||||||
websocket_handler, host, port)
|
async def _wrapper(ws, url):
|
||||||
|
# We wrap the main websocket_handler
|
||||||
|
# so we can pass quart's app object.
|
||||||
|
await websocket_handler(app, ws, url)
|
||||||
|
|
||||||
|
ws_future = websockets.serve(_wrapper, host, port)
|
||||||
|
|
||||||
await ws_future
|
await ws_future
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue