diff --git a/Pipfile b/Pipfile index 7c0cdbb..b8057d1 100644 --- a/Pipfile +++ b/Pipfile @@ -9,6 +9,7 @@ itsdangerous = "==0.24" asyncpg = "==0.16.0" websockets = "==5.0.1" Quart = "==0.6.0" +Earl-ETF = "==2.1.2" [dev-packages] diff --git a/Pipfile.lock b/Pipfile.lock index 57da1a8..a3b7b20 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "e37a82fc53dadfc4b8ebbded3bb043d686fa5c2cde07b11589430586e236386b" + "sha256": "f4558d5a01c7a8954d1b4f60042c3599189df0e720c235c89bb41facb9d704ab" }, "host-environment-markers": { "implementation_name": "cpython", @@ -9,9 +9,9 @@ "os_name": "posix", "platform_machine": "x86_64", "platform_python_implementation": "CPython", - "platform_release": "4.16.13-2-ARCH", + "platform_release": "4.17.2-1-ARCH", "platform_system": "Linux", - "platform_version": "#1 SMP PREEMPT Fri Jun 1 18:46:11 UTC 2018", + "platform_version": "#1 SMP PREEMPT Sat Jun 16 11:08:59 UTC 2018", "python_full_version": "3.6.5", "python_version": "3.6", "sys_platform": "linux" @@ -53,6 +53,10 @@ "hashes": [], "version": "==6.7" }, + "earl-etf": { + "hashes": [], + "version": "==2.1.2" + }, "h11": { "hashes": [], "version": "==0.7.0" diff --git a/litecord/gateway/errors.py b/litecord/gateway/errors.py new file mode 100644 index 0000000..e877dc9 --- /dev/null +++ b/litecord/gateway/errors.py @@ -0,0 +1,17 @@ +from ..errors import WebsocketClose + + +class UnknownOPCode(WebsocketClose): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # hacky solution to + # decrease code repetition + self.args = [4001, self.args[0]] + + +class DecodeError(WebsocketClose): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.args = [4002, self.args[0]] diff --git a/litecord/gateway/gateway.py b/litecord/gateway/gateway.py index ec2cff6..6ecde18 100644 --- a/litecord/gateway/gateway.py +++ b/litecord/gateway/gateway.py @@ -1,4 +1,5 @@ import urllib.parse +from .websocket import GatewayWebsocket async def websocket_handler(ws, url): @@ -26,5 +27,6 @@ async def websocket_handler(ws, url): if gw_compress and gw_compress not in ('zlib-stream',): return await ws.close(1000, 'Invalid gateway compress') - await ws.close(code=1000, reason='ass') - return + gws = GatewayWebsocket(ws, v=gw_version, + encoding=gw_encoding, compress=gw_compress) + await gws.run() diff --git a/litecord/gateway/opcodes.py b/litecord/gateway/opcodes.py new file mode 100644 index 0000000..3bfa8c2 --- /dev/null +++ b/litecord/gateway/opcodes.py @@ -0,0 +1,14 @@ +class OP: + """Gateway OP codes.""" + DISPATCH = 0 + HEARTBEAT = 1 + IDENTIFY = 2 + STATUS_UPDATE = 3 + VOICE_UPDATE = 4 + VOICE_PING = 5 + RESUME = 6 + RECONNECT = 7 + REQ_GUILD_MEMBERS = 8 + INVALID_SESSION = 9 + HELLO = 10 + HEARTBEAT_ACK = 11 diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py new file mode 100644 index 0000000..d98cd41 --- /dev/null +++ b/litecord/gateway/state.py @@ -0,0 +1,7 @@ +class GatewayState: + """Main websocket state. + + Used to store all information tied to the websocket's session. + """ + def __init__(self, **kwargs): + pass diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py new file mode 100644 index 0000000..86d9673 --- /dev/null +++ b/litecord/gateway/websocket.py @@ -0,0 +1,100 @@ +import json + +import earl + +from ..errors import WebsocketClose +from .errors import DecodeError, UnknownOPCode +from .opcodes import OP + + +def encode_json(payload) -> str: + return json.dumps(payload) + + +def decode_json(data: str): + return json.loads(data) + + +def encode_etf(payload) -> str: + return earl.pack(payload) + + +def decode_etf(data): + return earl.unpack(data) + + +class GatewayWebsocket: + """Main gateway websocket logic.""" + def __init__(self, ws, **kwargs): + self.ws = ws + self.version = kwargs.get('v', 6) + self.encoding = kwargs.get('encoding', 'json') + self.compress = kwargs.get('compress', None) + + self.set_encoders() + + def set_encoders(self): + encoding = self.encoding + + encodings = { + 'json': (encode_json, decode_json), + 'etf': (encode_etf, decode_etf), + } + + self.encoder, self.decoder = encodings[encoding] + + async def send(self, payload): + encoded = self.encoder(payload) + + # TODO: compression + + await self.ws.send(encoded) + + async def send_hello(self): + """Send the OP 10 Hello""" + await self.send({ + 'op': OP.HELLO, + 'd': { + 'heartbeat_interval': 45000, + '_trace': [ + 'despacito' + ], + } + }) + + async def handle_0(self, payload): + pass + + async def process_message(self, payload): + """Process a single message coming in from the client.""" + try: + op_code = payload['op'] + except KeyError: + raise UnknownOPCode('No OP code') + + try: + handler = getattr(self, f'handle_{op_code}') + except AttributeError: + raise UnknownOPCode('Bad OP code') + + await handler(payload) + + async def listen_messages(self): + """Listen for messages coming in from the websocket.""" + while True: + message = await self.ws.recv() + if len(message) > 4096: + raise DecodeError('Payload length exceeded') + + payload = self.decoder(message) + await self.process_message(payload) + + async def run(self): + """Wrap listen_messages inside + a try/except block for WebsocketClose handling.""" + try: + await self.send_hello() + await self.listen_messages() + except WebsocketClose as err: + await self.ws.close(code=err.code, + reason=err.reason)