diff --git a/litecord/auth.py b/litecord/auth.py index d766d90..fa8404b 100644 --- a/litecord/auth.py +++ b/litecord/auth.py @@ -48,4 +48,7 @@ async def token_check(): except KeyError: raise Unauthorized('No token provided') + if token.startswith('Bot '): + token = token.replace('Bot ', '') + return await raw_token_check(token) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index d8dc530..4becb79 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -1,5 +1,7 @@ import json import collections +import pprint +import zlib from typing import List import earl @@ -15,7 +17,7 @@ from .state import GatewayState log = Logger(__name__) WebsocketProperties = collections.namedtuple( - 'WebsocketProperties', 'v encoding compress' + 'WebsocketProperties', 'v encoding compress zctx' ) WebsocketObjects = collections.namedtuple( @@ -49,7 +51,8 @@ class GatewayWebsocket: self.wsp = WebsocketProperties(kwargs.get('v'), kwargs.get('encoding', 'json'), - kwargs.get('compress', None)) + kwargs.get('compress', None), + zlib.compressobj()) self.state = None @@ -67,11 +70,20 @@ class GatewayWebsocket: async def send(self, payload: dict): """Send a payload to the websocket""" + log.debug('Sending {}', pprint.pformat(payload)) encoded = self.encoder(payload) - # TODO: compression + if not isinstance(encoded, bytes): + encoded = encoded.encode() - await self.ws.send(encoded) + if self.wsp.compress == 'zlib-stream': + data1 = self.wsp.zctx.compress(encoded) + data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH) + + await self.ws.send(data1 + data2) + else: + # TODO: pure zlib + await self.ws.send(encoded) async def send_hello(self): """Send the OP 10 Hello packet over the websocket.""" @@ -173,8 +185,11 @@ class GatewayWebsocket: if current_shard > shard_count: raise InvalidShard('Shard count > Total shards') - async def handle_0(self, payload: dict): - """Handle the OP 0 Identify packet.""" + async def handle_1(self, payload: dict): + pass + + async def handle_2(self, payload: dict): + """Handle the OP 2 Identify packet.""" data = payload['d'] try: token, properties = data['token'], data['properties'] @@ -237,6 +252,10 @@ class GatewayWebsocket: raise DecodeError('Payload length exceeded') payload = self.decoder(message) + + pretty_printed = pprint.pformat(payload) + log.debug('received message: {}', pretty_printed) + await self.process_message(payload) async def run(self): diff --git a/litecord/storage.py b/litecord/storage.py index aac208f..f66952a 100644 --- a/litecord/storage.py +++ b/litecord/storage.py @@ -52,10 +52,15 @@ class Storage: # (No, changing the types on the db wouldn't be nice) drow['id'] = str(drow['id']) drow['owner_id'] = str(drow['owner_id']) - drow['afk_channel_id'] = str(drow['afk_channel_id']) - drow['embed_channel_id'] = str(drow['embed_channel_id']) - drow['widget_channel_id'] = str(drow['widget_channel_id']) - drow['system_channel_id'] = str(drow['system_channel_id']) + drow['afk_channel_id'] = str(drow['afk_channel_id']) \ + if drow['afk_channel_id'] else None + drow['embed_channel_id'] = str(drow['embed_channel_id']) \ + if drow['embed_channel_id'] else None + + drow['widget_channel_id'] = str(drow['widget_channel_id']) \ + if drow['widget_channel_id'] else None + drow['system_channel_id'] = str(drow['system_channel_id']) \ + if drow['system_channel_id'] else None return {**drow, **{ 'roles': [], @@ -185,6 +190,7 @@ class Storage: """, row['id']) res = await self._channels_extra(dict(row), ctype) + res['type'] = ctype # type is a SQL keyword, so we can't do # 'overwrite_type AS type' diff --git a/run.py b/run.py index 339bf3c..9a1f74b 100644 --- a/run.py +++ b/run.py @@ -6,6 +6,7 @@ import logbook import websockets from quart import Quart, g, jsonify from logbook import StreamHandler, Logger +from logbook.compat import redirect_logging import config from litecord.blueprints import gateway, auth, users, guilds @@ -19,6 +20,7 @@ from litecord.dispatcher import EventDispatcher handler = StreamHandler(sys.stdout, level=logbook.INFO) handler.push_application() log = Logger('litecord.boot') +redirect_logging() def make_app():