diff --git a/litecord/auth.py b/litecord/auth.py index e325d6b..1c95efb 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -1,5 +1,6 @@ import base64 import logging +import binascii from itsdangerous import Signer, BadSignature from quart import request, current_app as app @@ -10,19 +11,13 @@ from .errors import AuthError log = logging.getLogger(__name__) -async def token_check(): - """Check token information.""" - try: - token = request.headers['Authorization'] - except KeyError: - raise AuthError('No token provided') - +async def raw_token_check(token): user_id, _hmac = token.split('.') - user_id = base64.b64decode(user_id.encode('utf-8')) try: + user_id = base64.b64decode(user_id.encode('utf-8')) user_id = int(user_id) - except ValueError: + except (ValueError, binascii.Error): raise AuthError('Invalid user ID type') pwd_hash = await app.db.fetchval(""" @@ -43,3 +38,13 @@ async def token_check(): except BadSignature: log.warning('token fail for uid {user_id}') raise AuthError('Invalid token') + + +async def token_check(): + """Check token information.""" + try: + token = request.headers['Authorization'] + except KeyError: + raise AuthError('No token provided') + + await raw_token_check(token) diff --git a/litecord/gateway/gateway.py b/litecord/gateway/gateway.py index 6ecde18..8b6e8cf 100644 --- a/litecord/gateway/gateway.py +++ b/litecord/gateway/gateway.py @@ -2,7 +2,7 @@ import urllib.parse from .websocket import GatewayWebsocket -async def websocket_handler(ws, url): +async def websocket_handler(app, ws, url): qs = urllib.parse.parse_qs( urllib.parse.urlparse(url).query ) @@ -27,6 +27,6 @@ async def websocket_handler(ws, url): if gw_compress and gw_compress not in ('zlib-stream',): return await ws.close(1000, 'Invalid gateway compress') - gws = GatewayWebsocket(ws, v=gw_version, + gws = GatewayWebsocket(app, ws, v=gw_version, encoding=gw_encoding, compress=gw_compress) await gws.run() diff --git a/litecord/gateway/opcodes.py b/litecord/gateway/opcodes.py index 3bfa8c2..1cf83de 100644 --- a/litecord/gateway/opcodes.py +++ b/litecord/gateway/opcodes.py @@ -12,3 +12,4 @@ class OP: INVALID_SESSION = 9 HELLO = 10 HEARTBEAT_ACK = 11 + GUILD_SYNC = 12 diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index d98cd41..01e7686 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -1,3 +1,12 @@ +import hashlib +import os + + +def gen_session_id() -> str: + """Generate a random session ID.""" + return hashlib.sha1(os.urandom(256)).hexdigest() + + class GatewayState: """Main websocket state. diff --git a/litecord/gateway/state_man.py b/litecord/gateway/state_man.py new file mode 100644 index 0000000..0c1de68 --- /dev/null +++ b/litecord/gateway/state_man.py @@ -0,0 +1,16 @@ +from .state import GatewayState + + +class StateManager: + """Manager for gateway state information.""" + def __init__(self): + self.states = {} + + def insert(self, state: GatewayState): + """Insert a new state object.""" + user_states = self.states[state.user_id] + user_states[state.session_id] = state + + def fetch(self, user_id: int, session_id: str) -> GatewayState: + """Fetch a state object from the registry.""" + return self.states[user_id][session_id] diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 86d9673..c7616a6 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -1,10 +1,19 @@ import json +import logging +import collections import earl -from ..errors import WebsocketClose +from ..errors import WebsocketClose, AuthError +from ..auth import raw_token_check from .errors import DecodeError, UnknownOPCode from .opcodes import OP +from .state import GatewayState, gen_session_id + + +log = logging.getLogger(__name__) +WebsocketProperties = collections.namedtuple( + 'WebsocketProperties', 'v encoding compress') def encode_json(payload) -> str: @@ -25,16 +34,20 @@ def decode_etf(data): class GatewayWebsocket: """Main gateway websocket logic.""" - def __init__(self, ws, **kwargs): + def __init__(self, app, ws, **kwargs): + self.app = app self.ws = ws - self.version = kwargs.get('v', 6) - self.encoding = kwargs.get('encoding', 'json') - self.compress = kwargs.get('compress', None) - self.set_encoders() + self.wsp = WebsocketProperties(kwargs.get('v'), + kwargs.get('encoding', 'json'), + kwargs.get('compress', None)) - def set_encoders(self): - encoding = self.encoding + self.state = None + + self._set_encoders() + + def _set_encoders(self): + encoding = self.wsp.encoding encodings = { 'json': (encode_json, decode_json), @@ -43,7 +56,8 @@ class GatewayWebsocket: self.encoder, self.decoder = encodings[encoding] - async def send(self, payload): + async def send(self, payload: dict): + """Send a payload to the websocket""" encoded = self.encoder(payload) # TODO: compression @@ -51,7 +65,7 @@ class GatewayWebsocket: await self.ws.send(encoded) async def send_hello(self): - """Send the OP 10 Hello""" + """Send the OP 10 Hello packet over the websocket.""" await self.send({ 'op': OP.HELLO, 'd': { @@ -62,8 +76,57 @@ class GatewayWebsocket: } }) - async def handle_0(self, payload): - pass + async def dispatch(self, event, data): + """Dispatch an event to the websocket.""" + await self.send({ + 'op': OP.DISPATCH, + 't': event.upper(), + # 's': self.state.seq, + 'd': data, + }) + + async def handle_0(self, payload: dict): + """Handle the OP 0 Identify packet.""" + data = payload['d'] + try: + token, properties = data['token'], data['properties'] + except KeyError: + raise DecodeError('Invalid identify parameters') + + compress = data.get('compress', False) + large = data.get('large_threshold', 50) + + shard = data.get('shard', [0, 1]) + presence = data.get('presence') + + try: + user_id = await raw_token_check(token) + except AuthError: + raise WebsocketClose(4004, 'Authentication failed') + + session_id = gen_session_id() + + self.state = GatewayState( + session_id=session_id, + user_id=user_id, + properties=properties, + compress=compress, + large=large, + shard=shard, + presence=presence + ) + + self.app.state_manager.insert(self.state) + + # TODO: dispatch READY + await self.dispatch('READY', { + 'v': 6, + 'user': {'i': 'Boobs !! ! .........'}, + 'private_channels': [], + 'guilds': [], + 'session_id': session_id, + '_trace': ['despacito'] + }) async def process_message(self, payload): """Process a single message coming in from the client.""" @@ -96,5 +159,6 @@ class GatewayWebsocket: await self.send_hello() await self.listen_messages() except WebsocketClose as err: + log.warning(f'Closed a client, {self.state or ""} {err!r}') await self.ws.close(code=err.code, reason=err.reason) diff --git a/run.py b/run.py index 28f2898..b1b41b1 100644 --- a/run.py +++ b/run.py @@ -10,6 +10,7 @@ import config from litecord.blueprints import gateway, auth from litecord.gateway import websocket_handler from litecord.errors import LitecordError +from litecord.gateway.state_man import StateManager logging.basicConfig(level=logging.INFO) log = logging.getLogger(__name__) @@ -39,11 +40,18 @@ async def app_before_serving(): app.loop = asyncio.get_event_loop() g.loop = asyncio.get_event_loop() + app.state_manager = StateManager() + # start the websocket, etc host, port = app.config['WS_HOST'], app.config['WS_PORT'] log.info(f'starting websocket at {host} {port}') - ws_future = websockets.serve( - websocket_handler, host, port) + + async def _wrapper(ws, url): + # We wrap the main websocket_handler + # so we can pass quart's app object. + await websocket_handler(app, ws, url) + + ws_future = websockets.serve(_wrapper, host, port) await ws_future