From d62db421b0afd9ae567f9f43112838a5d79a721f Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 24 Jun 2018 18:28:28 -0300 Subject: [PATCH] gateway.websocket: add basics of resuming - gateway.state: add PayloadStore - state: add last_seq - gateway.websocket: send string on non zlib-stream - gateway.websocket: add cleanup of state on ws close --- litecord/gateway/state.py | 14 ++++++ litecord/gateway/websocket.py | 91 ++++++++++++++++++++++++++++++----- 2 files changed, 94 insertions(+), 11 deletions(-) diff --git a/litecord/gateway/state.py b/litecord/gateway/state.py index e6101ab..ae58bd1 100644 --- a/litecord/gateway/state.py +++ b/litecord/gateway/state.py @@ -7,6 +7,18 @@ def gen_session_id() -> str: return hashlib.sha1(os.urandom(256)).hexdigest() +class PayloadStore: + """Store manager for payloads.""" + def __init__(self): + self.store = {} + + def __getitem__(self, opcode: int): + return self.store[opcode] + + def __setitem__(self, opcode: int, payload: dict): + self.store[opcode] = payload + + class GatewayState: """Main websocket state. @@ -16,9 +28,11 @@ class GatewayState: def __init__(self, **kwargs): self.session_id = kwargs.get('session_id', gen_session_id()) self.seq = kwargs.get('seq', 0) + self.last_seq = 0 self.shard = kwargs.get('shard', [0, 1]) self.user_id = kwargs.get('user_id') self.bot = kwargs.get('bot', False) + self.store = PayloadStore() for key in kwargs: value = kwargs[key] diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 5ca13e2..9e22b6e 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -2,9 +2,10 @@ import json import collections import pprint import zlib -from typing import List +from typing import List, Dict, Any import earl +import websockets from logbook import Logger from litecord.errors import WebsocketClose, Unauthorized, Forbidden @@ -68,14 +69,19 @@ class GatewayWebsocket: self.encoder, self.decoder = encodings[encoding] - async def send(self, payload: dict): - """Send a payload to the websocket""" + async def send(self, payload: Dict[str, Any]): + """Send a payload to the websocket. + + This function accounts for the zlib-stream + transport method used by Discord. + """ log.debug('Sending {}', pprint.pformat(payload)) encoded = self.encoder(payload) if not isinstance(encoded, bytes): encoded = encoded.encode() + print(self.wsp.compress) if self.wsp.compress == 'zlib-stream': data1 = self.wsp.zctx.compress(encoded) data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH) @@ -83,7 +89,7 @@ class GatewayWebsocket: await self.ws.send(data1 + data2) else: # TODO: pure zlib - await self.ws.send(encoded) + await self.ws.send(encoded.decode()) async def send_hello(self): """Send the OP 10 Hello packet over the websocket.""" @@ -97,16 +103,19 @@ class GatewayWebsocket: } }) - async def dispatch(self, event, data): + async def dispatch(self, event: str, data: Any): """Dispatch an event to the websocket.""" self.state.seq += 1 - await self.send({ + payload = { 'op': OP.DISPATCH, 't': event.upper(), 's': self.state.seq, 'd': data, - }) + } + + self.state.store[self.state.seq] = payload + await self.send(payload) async def _make_guild_list(self) -> List[int]: # TODO: This function does not account for sharding. @@ -129,7 +138,7 @@ class GatewayWebsocket: for row in guild_ids ] - async def guild_dispatch(self, unavailable_guilds: List[dict]): + async def guild_dispatch(self, unavailable_guilds: List[Dict[str, Any]]): """Dispatch GUILD_CREATE information.""" # Users don't get asynchronous guild dispatching. @@ -237,10 +246,11 @@ class GatewayWebsocket: if current_shard > shard_count: raise InvalidShard('Shard count > Total shards') - async def handle_1(self, payload: dict): + async def handle_1(self, payload: Dict[str, Any]): + """Handle OP 1 Heartbeat packets.""" pass - async def handle_2(self, payload: dict): + async def handle_2(self, payload: Dict[str, Any]): """Handle the OP 2 Identify packet.""" data = payload['d'] try: @@ -282,6 +292,60 @@ class GatewayWebsocket: self.ext.state_manager.insert(self.state) await self.dispatch_ready() + async def handle_3(self, payload: Dict[str, Any]): + """Handle OP 3 Status Update.""" + pass + + async def handle_6(self, payload: Dict[str, Any]): + """Handle OP 6 Resume.""" + data = payload['d'] + + try: + token, sess_id, seq = data['token'], \ + data['session_id'], data['seq'] + except KeyError: + raise DecodeError('Invalid resume payload') + + try: + user_id = await raw_token_check(token, self.ext.db) + except (Unauthorized, Forbidden): + raise WebsocketClose(4004, 'Invalid token') + + try: + state = self.ext.state_manager.fetch(user_id, sess_id) + except KeyError: + return await self.send({ + 'op': 9, + 'd': False, + }) + + if seq > state.seq: + raise WebsocketClose(4007, 'Invalid seq') + + # check if a websocket isnt on that state already + if state.ws is not None: + log.info('Resuming failed, websocket already connected') + return await self.send({ + 'op': 9, + 'd': False, + }) + + # relink this connection + self.state = state + state.ws = self + + # TODO: resend payloads + + await self.dispatch('RESUMED', {}) + + async def handle_12(self, payload: Dict[str, Any]): + """Handle OP 12 Guild Sync.""" + data = payload['d'] + + for _guild_id in data: + # check if user in guild + pass + async def process_message(self, payload): """Process a single message coming in from the client.""" try: @@ -292,7 +356,7 @@ class GatewayWebsocket: try: handler = getattr(self, f'handle_{op_code}') except AttributeError: - raise UnknownOPCode('Bad OP code') + raise UnknownOPCode(f'Bad OP code: {op_code}') await handler(payload) @@ -316,7 +380,12 @@ class GatewayWebsocket: try: await self.send_hello() await self.listen_messages() + except websockets.exceptions.ConnectionClosed as err: + log.warning('Client closed, state={}, err={}', self.state, err) except WebsocketClose as err: log.warning('closed a client, state={} err={}', self.state, err) await self.ws.close(code=err.code, reason=err.reason) + finally: + if self.state: + self.state.ws = None