mirror of https://gitlab.com/litecord/litecord.git
gateway: move encoding to litecord.gateway.encoding
This commit is contained in:
parent
df7f2b1b21
commit
9142e26152
|
|
@ -0,0 +1,84 @@
|
||||||
|
"""
|
||||||
|
|
||||||
|
Litecord
|
||||||
|
Copyright (C) 2018-2019 Luna Mendes
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, version 3 of the License.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import earl
|
||||||
|
|
||||||
|
from litecord.utils import LitecordJSONEncoder
|
||||||
|
|
||||||
|
|
||||||
|
def encode_json(payload) -> str:
|
||||||
|
"""Encode a given payload to JSON."""
|
||||||
|
return json.dumps(payload, separators=(',', ':'),
|
||||||
|
cls=LitecordJSONEncoder)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_json(data: str):
|
||||||
|
"""Decode from JSON."""
|
||||||
|
return json.loads(data)
|
||||||
|
|
||||||
|
|
||||||
|
def encode_etf(payload) -> str:
|
||||||
|
"""Encode a payload to ETF (External Term Format).
|
||||||
|
|
||||||
|
This gives a JSON pass on the given payload (via calling encode_json and
|
||||||
|
then decode_json) because we may want to encode objects that can only be
|
||||||
|
encoded by LitecordJSONEncoder.
|
||||||
|
|
||||||
|
Earl-ETF does not give the same interface for extensibility, hence why we
|
||||||
|
do the pass.
|
||||||
|
"""
|
||||||
|
sanitized = encode_json(payload)
|
||||||
|
sanitized = decode_json(sanitized)
|
||||||
|
return earl.pack(sanitized)
|
||||||
|
|
||||||
|
|
||||||
|
def _etf_decode_dict(data):
|
||||||
|
"""Decode a given dictionary."""
|
||||||
|
# NOTE: this is very slow.
|
||||||
|
|
||||||
|
if isinstance(data, bytes):
|
||||||
|
return data.decode()
|
||||||
|
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
return data
|
||||||
|
|
||||||
|
_copy = dict(data)
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
for key in _copy.keys():
|
||||||
|
# assuming key is bytes rn.
|
||||||
|
new_k = key.decode()
|
||||||
|
|
||||||
|
# maybe nested dicts, so...
|
||||||
|
result[new_k] = _etf_decode_dict(data[key])
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def decode_etf(data: bytes):
|
||||||
|
"""Decode data in ETF to any."""
|
||||||
|
res = earl.unpack(data)
|
||||||
|
|
||||||
|
if isinstance(res, bytes):
|
||||||
|
return data.decode()
|
||||||
|
|
||||||
|
if isinstance(res, dict):
|
||||||
|
return _etf_decode_dict(res)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
@ -50,7 +50,7 @@ async def websocket_handler(app, ws, url):
|
||||||
except (KeyError, IndexError):
|
except (KeyError, IndexError):
|
||||||
gw_compress = None
|
gw_compress = None
|
||||||
|
|
||||||
if gw_compress and gw_compress not in ('zlib-stream',):
|
if gw_compress and gw_compress not in ('zlib-stream', 'zstd-stream'):
|
||||||
return await ws.close(1000, 'Invalid gateway compress')
|
return await ws.close(1000, 'Invalid gateway compress')
|
||||||
|
|
||||||
gws = GatewayWebsocket(
|
gws = GatewayWebsocket(
|
||||||
|
|
|
||||||
|
|
@ -21,19 +21,17 @@ import collections
|
||||||
import asyncio
|
import asyncio
|
||||||
import pprint
|
import pprint
|
||||||
import zlib
|
import zlib
|
||||||
import json
|
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
from random import randint
|
from random import randint
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
import earl
|
|
||||||
|
|
||||||
from litecord.auth import raw_token_check
|
from litecord.auth import raw_token_check
|
||||||
from litecord.enums import RelationshipType, ChannelType
|
from litecord.enums import RelationshipType, ChannelType
|
||||||
from litecord.schemas import validate, GW_STATUS_UPDATE
|
from litecord.schemas import validate, GW_STATUS_UPDATE
|
||||||
from litecord.utils import (
|
from litecord.utils import (
|
||||||
task_wrapper, LitecordJSONEncoder, yield_chunks
|
task_wrapper, yield_chunks
|
||||||
)
|
)
|
||||||
from litecord.permissions import get_permissions
|
from litecord.permissions import get_permissions
|
||||||
|
|
||||||
|
|
@ -46,6 +44,9 @@ from litecord.errors import (
|
||||||
from litecord.gateway.errors import (
|
from litecord.gateway.errors import (
|
||||||
DecodeError, UnknownOPCode, InvalidShard, ShardingRequired
|
DecodeError, UnknownOPCode, InvalidShard, ShardingRequired
|
||||||
)
|
)
|
||||||
|
from litecord.gateway.encoding import (
|
||||||
|
encode_json, decode_json, encode_etf, decode_etf
|
||||||
|
)
|
||||||
|
|
||||||
from litecord.storage import int_
|
from litecord.storage import int_
|
||||||
|
|
||||||
|
|
@ -64,67 +65,6 @@ WebsocketObjects = collections.namedtuple(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def encode_json(payload) -> str:
|
|
||||||
"""Encode a given payload to JSON."""
|
|
||||||
return json.dumps(payload, separators=(',', ':'),
|
|
||||||
cls=LitecordJSONEncoder)
|
|
||||||
|
|
||||||
|
|
||||||
def decode_json(data: str):
|
|
||||||
"""Decode from JSON."""
|
|
||||||
return json.loads(data)
|
|
||||||
|
|
||||||
|
|
||||||
def encode_etf(payload) -> str:
|
|
||||||
"""Encode a payload to ETF (External Term Format).
|
|
||||||
|
|
||||||
This gives a JSON pass on the given payload (via calling encode_json and
|
|
||||||
then decode_json) because we may want to encode objects that can only be
|
|
||||||
encoded by LitecordJSONEncoder.
|
|
||||||
|
|
||||||
Earl-ETF does not give the same interface for extensibility, hence why we
|
|
||||||
do the pass.
|
|
||||||
"""
|
|
||||||
sanitized = encode_json(payload)
|
|
||||||
sanitized = decode_json(sanitized)
|
|
||||||
return earl.pack(sanitized)
|
|
||||||
|
|
||||||
|
|
||||||
def _etf_decode_dict(data):
|
|
||||||
"""Decode a given dictionary."""
|
|
||||||
# NOTE: this is very slow.
|
|
||||||
|
|
||||||
if isinstance(data, bytes):
|
|
||||||
return data.decode()
|
|
||||||
|
|
||||||
if not isinstance(data, dict):
|
|
||||||
return data
|
|
||||||
|
|
||||||
_copy = dict(data)
|
|
||||||
result = {}
|
|
||||||
|
|
||||||
for key in _copy.keys():
|
|
||||||
# assuming key is bytes rn.
|
|
||||||
new_k = key.decode()
|
|
||||||
|
|
||||||
# maybe nested dicts, so...
|
|
||||||
result[new_k] = _etf_decode_dict(data[key])
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def decode_etf(data: bytes):
|
|
||||||
"""Decode data in ETF to any."""
|
|
||||||
res = earl.unpack(data)
|
|
||||||
|
|
||||||
if isinstance(res, bytes):
|
|
||||||
return data.decode()
|
|
||||||
|
|
||||||
if isinstance(res, dict):
|
|
||||||
return _etf_decode_dict(res)
|
|
||||||
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
class GatewayWebsocket:
|
class GatewayWebsocket:
|
||||||
"""Main gateway websocket logic."""
|
"""Main gateway websocket logic."""
|
||||||
|
|
||||||
|
|
@ -210,6 +150,9 @@ class GatewayWebsocket:
|
||||||
await self._chunked_send(data1, 1024)
|
await self._chunked_send(data1, 1024)
|
||||||
await self._chunked_send(data2, 1024)
|
await self._chunked_send(data2, 1024)
|
||||||
|
|
||||||
|
async def _zstd_stream_send(self, encoded):
|
||||||
|
pass
|
||||||
|
|
||||||
async def send(self, payload: Dict[str, Any]):
|
async def send(self, payload: Dict[str, Any]):
|
||||||
"""Send a payload to the websocket.
|
"""Send a payload to the websocket.
|
||||||
|
|
||||||
|
|
@ -233,6 +176,8 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
if self.wsp.compress == 'zlib-stream':
|
if self.wsp.compress == 'zlib-stream':
|
||||||
await self._zlib_stream_send(encoded)
|
await self._zlib_stream_send(encoded)
|
||||||
|
elif self.wsp.compress == 'zstd-stream':
|
||||||
|
await self._zstd_stream_send(encoded)
|
||||||
elif self.state and self.state.compress and len(encoded) > 1024:
|
elif self.state and self.state.compress and len(encoded) > 1024:
|
||||||
# TODO: should we only compress on >1KB packets? or maybe we
|
# TODO: should we only compress on >1KB packets? or maybe we
|
||||||
# should do all?
|
# should do all?
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue