gateway.websocket: do chunked sends on zlib stream

this should fix issues with big payloads being sent
as a single large websocket message to a client
(potentially crashing it).

chunked sends split the payload into 1KB chunks that are each
sent through the websocket. clients are already expected to
handle this behavior, per the zlib-stream docs.

 - utils: add yield_chunks
Luna 2018-12-10 01:44:44 -03:00
parent 5c38198137
commit 6872139ff6
3 changed files with 59 additions and 9 deletions
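For reference, this is roughly what a zlib-stream client is expected to do (a minimal sketch, not code from this repository; ZlibStreamReader and its names are hypothetical): buffer incoming websocket messages and only decompress once the buffer ends with the four-byte ZLIB_SUFFIX that marks the end of a payload.

import zlib

ZLIB_SUFFIX = b'\x00\x00\xff\xff'


class ZlibStreamReader:
    """Reassemble chunked zlib-stream messages into payloads."""
    def __init__(self):
        self.buf = bytearray()
        # one shared decompression context for the whole connection
        self.inflator = zlib.decompressobj()

    def feed(self, msg: bytes):
        """Feed one websocket message. Returns the decoded payload,
        or None if more chunks are needed."""
        self.buf.extend(msg)

        # a payload is complete only when the buffer ends with ZLIB_SUFFIX
        if len(self.buf) < 4 or self.buf[-4:] != ZLIB_SUFFIX:
            return None

        payload = self.inflator.decompress(bytes(self.buf))
        self.buf.clear()
        return payload.decode('utf-8')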


@@ -54,8 +54,6 @@ async def websocket_handler(app, ws, url):
     if gw_compress and gw_compress not in ('zlib-stream',):
         return await ws.close(1000, 'Invalid gateway compress')
 
-    print('encoding', gw_encoding, 'compression', gw_compress)
-
     gws = GatewayWebsocket(ws, app, v=gw_version,
                            encoding=gw_encoding, compress=gw_compress)
     await gws.run()


@@ -32,7 +32,9 @@ import earl
 from litecord.auth import raw_token_check
 from litecord.enums import RelationshipType
 from litecord.schemas import validate, GW_STATUS_UPDATE
-from litecord.utils import task_wrapper, LitecordJSONEncoder
+from litecord.utils import (
+    task_wrapper, LitecordJSONEncoder, yield_chunks
+)
 from litecord.permissions import get_permissions
 from litecord.gateway.opcodes import OP
@@ -41,7 +43,7 @@ from litecord.gateway.state import GatewayState
 from litecord.errors import (
     WebsocketClose, Unauthorized, Forbidden, BadRequest
 )
-from .errors import (
+from litecord.gateway.errors import (
     DecodeError, UnknownOPCode, InvalidShard, ShardingRequired
 )
@@ -125,6 +127,8 @@ class GatewayWebsocket:
             zlib.compressobj(),
             {})
 
+        log.debug('websocket properties: {!r}', self.wsp)
+
         self.state = None
         self._set_encoders()
@@ -139,6 +143,41 @@
 
         self.encoder, self.decoder = encodings[encoding]
 
+    async def _chunked_send(self, data: bytes, chunk_size: int):
+        """Split data into chunk_size-big chunks and send them
+        over the websocket."""
+        total_chunks = 0
+
+        for chunk in yield_chunks(data, chunk_size):
+            total_chunks += 1
+            await self.ws.send(chunk)
+
+        log.debug('zlib-stream: sent {} chunks', total_chunks)
+
+    async def _zlib_stream_send(self, encoded):
+        """Send a single payload across multiple compressed
+        websocket messages."""
+        # compress and flush (for the rest of the compressed
+        # data + ZLIB_SUFFIX)
+        data1 = self.wsp.zctx.compress(encoded)
+        data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
+
+        log.debug('zlib-stream: length {} -> compressed ({} + {})',
+                  len(encoded), len(data1), len(data2))
+
+        # NOTE: the old approach was ws.send(data1 + data2).
+        # I changed this to a chunked send of data1 and data2
+        # because sending a potentially really big packet as a
+        # single message can bring problems to the network.
+
+        # clients should handle chunked sends (via detection of
+        # the ZLIB_SUFFIX appended to data2), so this shouldn't
+        # bring problems.
+
+        # TODO: the chunks are 1024 bytes (1KB); is this good enough?
+        await self._chunked_send(data1, 1024)
+        await self._chunked_send(data2, 1024)
+
     async def send(self, payload: Dict[str, Any]):
         """Send a payload to the websocket.
@@ -159,13 +198,11 @@
         if not isinstance(encoded, bytes):
             encoded = encoded.encode()
 
+        # handle zlib-stream, pure zlib or plain
         if self.wsp.compress == 'zlib-stream':
-            data1 = self.wsp.zctx.compress(encoded)
-            data2 = self.wsp.zctx.flush(zlib.Z_FULL_FLUSH)
-
-            await self.ws.send(data1 + data2)
+            await self._zlib_stream_send(encoded)
         elif self.state and self.state.compress and len(encoded) > 1024:
-            # TODO: should we only compress on >1KB packets? or maybe we
-            # should do all?
             await self.ws.send(zlib.compress(encoded))
         else:
             await self.ws.send(encoded.decode())
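The chunked send stays compatible with clients because the payload boundary is not the websocket message: flushing the shared compression context with Z_FULL_FLUSH (like Z_SYNC_FLUSH) always terminates the output with the four-byte ZLIB_SUFFIX, so a client just buffers chunks until it sees that suffix. A standalone snippet (not repo code) demonstrating the property:

import zlib

zctx = zlib.compressobj()
data1 = zctx.compress(b'{"op": 11, "d": null}')
data2 = zctx.flush(zlib.Z_FULL_FLUSH)

# the flush output always ends with the empty stored block
# 00 00 ff ff that zlib-stream clients look for
assert (data1 + data2).endswith(b'\x00\x00\xff\xff')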


@@ -158,3 +158,18 @@ async def pg_set_json(con):
         decoder=json.loads,
         schema='pg_catalog'
     )
+
+
+def yield_chunks(input_list: list, chunk_size: int):
+    """Yield successive chunk_size-sized chunks from input_list.
+
+    Taken from https://stackoverflow.com/a/312464.
+
+    Modified to make the linter happy (variable name changes,
+    typing, comments).
+    """
+
+    # range() accepts a step parameter, so we use it to walk
+    # the input in chunk_size-long jumps
+    for idx in range(0, len(input_list), chunk_size):
+        yield input_list[idx:idx + chunk_size]
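Despite the list annotation, yield_chunks works on any sliceable sequence, including the bytes that _chunked_send feeds it, since slicing semantics are the same. Example usage:

>>> list(yield_chunks(b'abcdefgh', 3))
[b'abc', b'def', b'gh']
>>> list(yield_chunks([1, 2, 3, 4, 5], 2))
[[1, 2], [3, 4], [5]]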