From c85f0806c7628ef62ec0bd4286ed6e47865f817a Mon Sep 17 00:00:00 2001 From: Luna Date: Tue, 25 Jan 2022 23:47:33 -0300 Subject: [PATCH] test_websocket: add test for zlib stream --- tests/test_websocket.py | 69 ++++++++++++++++++++++++++++++++++------- 1 file changed, 58 insertions(+), 11 deletions(-) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 166ad28..0d12787 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -18,6 +18,8 @@ along with this program. If not, see . """ import json +import zlib +from typing import Optional import pytest import websockets @@ -25,15 +27,42 @@ import websockets from litecord.gateway.opcodes import OP from litecord.gateway.websocket import decode_etf - -async def _json(conn): - frame = await conn.recv() - return json.loads(frame) +# Z_SYNC_FLUSH suffix +ZLIB_SUFFIX = b"\x00\x00\xff\xff" -async def _etf(conn): - frame = await conn.recv() - return decode_etf(frame) +async def _recv(conn, *, zlib_stream: bool): + if zlib_stream: + try: + conn._zlib_context + except AttributeError: + conn._zlib_context = zlib.decompressobj() + + # inspired by + # https://discord.com/developers/docs/topics/gateway#transport-compression-transport-compression-example + zlib_buffer = bytearray() + while True: + # keep receiving frames until we find the zlib prefix inside + msg = await conn.recv() + zlib_buffer.extend(msg) + if len(msg) < 4 or msg[-4:] != ZLIB_SUFFIX: + continue + + # NOTE: the message is utf-8 encoded. + msg = conn._zlib_context.decompress(zlib_buffer) + return msg + else: + return await conn.recv() + + +async def _json(conn, *, zlib_stream: bool = False): + data = await _recv(conn, zlib_stream=zlib_stream) + return json.loads(data) + + +async def _etf(conn, *, zlib_stream: bool = False): + data = await _recv(conn, zlib_stream=zlib_stream) + return decode_etf(data) async def _json_send(conn, data): @@ -49,8 +78,8 @@ async def _close(conn): await conn.close(1000, "test end") -async def extract_and_verify_ready(conn): - ready = await _json(conn) +async def extract_and_verify_ready(conn, **kwargs): + ready = await _json(conn, **kwargs) assert ready["op"] == OP.DISPATCH assert ready["t"] == "READY" @@ -78,7 +107,9 @@ async def get_gw(test_cli, version: int) -> str: return gw_json["url"] -async def gw_start(test_cli, *, version: int = 6, etf=False): +async def gw_start( + test_cli, *, version: int = 6, etf=False, compress: Optional[str] = None +): """Start a websocket connection""" gw_url = await get_gw(test_cli, version) @@ -87,7 +118,8 @@ async def gw_start(test_cli, *, version: int = 6, etf=False): else: gw_url = f"{gw_url}?v={version}&encoding=json" - return await websockets.connect(gw_url) + compress = f"&compress={compress}" if compress else "" + return await websockets.connect(f"{gw_url}{compress}") @pytest.mark.asyncio @@ -318,3 +350,18 @@ async def test_ready_bot(test_cli_bot): await extract_and_verify_ready(conn) finally: await _close(conn) + + +@pytest.mark.asyncio +async def test_ready_bot_zlib_stream(test_cli_bot): + conn = await gw_start(test_cli_bot.cli, compress="zlib-stream") + await _json(conn, zlib_stream=True) # ignore hello + await _json_send( + conn, + {"op": OP.IDENTIFY, "d": {"token": test_cli_bot.user["token"]}}, + ) + + try: + await extract_and_verify_ready(conn, zlib_stream=True) + finally: + await _close(conn)