diff --git a/litecord/gateway/encoding.py b/litecord/gateway/encoding.py new file mode 100644 index 0000000..07957f6 --- /dev/null +++ b/litecord/gateway/encoding.py @@ -0,0 +1,84 @@ +""" + +Litecord +Copyright (C) 2018-2019 Luna Mendes + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, version 3 of the License. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . + +""" + +import json +import earl + +from litecord.utils import LitecordJSONEncoder + + +def encode_json(payload) -> str: + """Encode a given payload to JSON.""" + return json.dumps(payload, separators=(',', ':'), + cls=LitecordJSONEncoder) + + +def decode_json(data: str): + """Decode from JSON.""" + return json.loads(data) + + +def encode_etf(payload) -> str: + """Encode a payload to ETF (External Term Format). + + This gives a JSON pass on the given payload (via calling encode_json and + then decode_json) because we may want to encode objects that can only be + encoded by LitecordJSONEncoder. + + Earl-ETF does not give the same interface for extensibility, hence why we + do the pass. + """ + sanitized = encode_json(payload) + sanitized = decode_json(sanitized) + return earl.pack(sanitized) + + +def _etf_decode_dict(data): + """Decode a given dictionary.""" + # NOTE: this is very slow. + + if isinstance(data, bytes): + return data.decode() + + if not isinstance(data, dict): + return data + + _copy = dict(data) + result = {} + + for key in _copy.keys(): + # assuming key is bytes rn. + new_k = key.decode() + + # maybe nested dicts, so... + result[new_k] = _etf_decode_dict(data[key]) + + return result + +def decode_etf(data: bytes): + """Decode data in ETF to any.""" + res = earl.unpack(data) + + if isinstance(res, bytes): + return data.decode() + + if isinstance(res, dict): + return _etf_decode_dict(res) + + return res diff --git a/litecord/gateway/gateway.py b/litecord/gateway/gateway.py index f435d16..a213528 100644 --- a/litecord/gateway/gateway.py +++ b/litecord/gateway/gateway.py @@ -50,7 +50,7 @@ async def websocket_handler(app, ws, url): except (KeyError, IndexError): gw_compress = None - if gw_compress and gw_compress not in ('zlib-stream',): + if gw_compress and gw_compress not in ('zlib-stream', 'zstd-stream'): return await ws.close(1000, 'Invalid gateway compress') gws = GatewayWebsocket( diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 335cfd1..7283f17 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -21,19 +21,17 @@ import collections import asyncio import pprint import zlib -import json from typing import List, Dict, Any from random import randint import websockets from logbook import Logger -import earl from litecord.auth import raw_token_check from litecord.enums import RelationshipType, ChannelType from litecord.schemas import validate, GW_STATUS_UPDATE from litecord.utils import ( - task_wrapper, LitecordJSONEncoder, yield_chunks + task_wrapper, yield_chunks ) from litecord.permissions import get_permissions @@ -46,6 +44,9 @@ from litecord.errors import ( from litecord.gateway.errors import ( DecodeError, UnknownOPCode, InvalidShard, ShardingRequired ) +from litecord.gateway.encoding import ( + encode_json, decode_json, encode_etf, decode_etf +) from litecord.storage import int_ @@ -64,67 +65,6 @@ WebsocketObjects = collections.namedtuple( ) -def encode_json(payload) -> str: - """Encode a given payload to JSON.""" - return json.dumps(payload, separators=(',', ':'), - cls=LitecordJSONEncoder) - - -def decode_json(data: str): - """Decode from JSON.""" - return json.loads(data) - - -def encode_etf(payload) -> str: - """Encode a payload to ETF (External Term Format). - - This gives a JSON pass on the given payload (via calling encode_json and - then decode_json) because we may want to encode objects that can only be - encoded by LitecordJSONEncoder. - - Earl-ETF does not give the same interface for extensibility, hence why we - do the pass. - """ - sanitized = encode_json(payload) - sanitized = decode_json(sanitized) - return earl.pack(sanitized) - - -def _etf_decode_dict(data): - """Decode a given dictionary.""" - # NOTE: this is very slow. - - if isinstance(data, bytes): - return data.decode() - - if not isinstance(data, dict): - return data - - _copy = dict(data) - result = {} - - for key in _copy.keys(): - # assuming key is bytes rn. - new_k = key.decode() - - # maybe nested dicts, so... - result[new_k] = _etf_decode_dict(data[key]) - - return result - -def decode_etf(data: bytes): - """Decode data in ETF to any.""" - res = earl.unpack(data) - - if isinstance(res, bytes): - return data.decode() - - if isinstance(res, dict): - return _etf_decode_dict(res) - - return res - - class GatewayWebsocket: """Main gateway websocket logic.""" @@ -210,6 +150,9 @@ class GatewayWebsocket: await self._chunked_send(data1, 1024) await self._chunked_send(data2, 1024) + async def _zstd_stream_send(self, encoded): + pass + async def send(self, payload: Dict[str, Any]): """Send a payload to the websocket. @@ -233,6 +176,8 @@ class GatewayWebsocket: if self.wsp.compress == 'zlib-stream': await self._zlib_stream_send(encoded) + elif self.wsp.compress == 'zstd-stream': + await self._zstd_stream_send(encoded) elif self.state and self.state.compress and len(encoded) > 1024: # TODO: should we only compress on >1KB packets? or maybe we # should do all?