mirror of https://gitlab.com/litecord/litecord.git
test_websocket: add test for zlib stream
This commit is contained in:
parent
f792769656
commit
c85f0806c7
|
|
@ -18,6 +18,8 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
"""
|
||||
|
||||
import json
|
||||
import zlib
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import websockets
|
||||
|
|
@ -25,15 +27,42 @@ import websockets
|
|||
from litecord.gateway.opcodes import OP
|
||||
from litecord.gateway.websocket import decode_etf
|
||||
|
||||
|
||||
async def _json(conn):
|
||||
frame = await conn.recv()
|
||||
return json.loads(frame)
|
||||
# Z_SYNC_FLUSH suffix
|
||||
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
|
||||
|
||||
|
||||
async def _etf(conn):
|
||||
frame = await conn.recv()
|
||||
return decode_etf(frame)
|
||||
async def _recv(conn, *, zlib_stream: bool):
|
||||
if zlib_stream:
|
||||
try:
|
||||
conn._zlib_context
|
||||
except AttributeError:
|
||||
conn._zlib_context = zlib.decompressobj()
|
||||
|
||||
# inspired by
|
||||
# https://discord.com/developers/docs/topics/gateway#transport-compression-transport-compression-example
|
||||
zlib_buffer = bytearray()
|
||||
while True:
|
||||
# keep receiving frames until we find the zlib prefix inside
|
||||
msg = await conn.recv()
|
||||
zlib_buffer.extend(msg)
|
||||
if len(msg) < 4 or msg[-4:] != ZLIB_SUFFIX:
|
||||
continue
|
||||
|
||||
# NOTE: the message is utf-8 encoded.
|
||||
msg = conn._zlib_context.decompress(zlib_buffer)
|
||||
return msg
|
||||
else:
|
||||
return await conn.recv()
|
||||
|
||||
|
||||
async def _json(conn, *, zlib_stream: bool = False):
|
||||
data = await _recv(conn, zlib_stream=zlib_stream)
|
||||
return json.loads(data)
|
||||
|
||||
|
||||
async def _etf(conn, *, zlib_stream: bool = False):
|
||||
data = await _recv(conn, zlib_stream=zlib_stream)
|
||||
return decode_etf(data)
|
||||
|
||||
|
||||
async def _json_send(conn, data):
|
||||
|
|
@ -49,8 +78,8 @@ async def _close(conn):
|
|||
await conn.close(1000, "test end")
|
||||
|
||||
|
||||
async def extract_and_verify_ready(conn):
|
||||
ready = await _json(conn)
|
||||
async def extract_and_verify_ready(conn, **kwargs):
|
||||
ready = await _json(conn, **kwargs)
|
||||
assert ready["op"] == OP.DISPATCH
|
||||
assert ready["t"] == "READY"
|
||||
|
||||
|
|
@ -78,7 +107,9 @@ async def get_gw(test_cli, version: int) -> str:
|
|||
return gw_json["url"]
|
||||
|
||||
|
||||
async def gw_start(test_cli, *, version: int = 6, etf=False):
|
||||
async def gw_start(
|
||||
test_cli, *, version: int = 6, etf=False, compress: Optional[str] = None
|
||||
):
|
||||
"""Start a websocket connection"""
|
||||
gw_url = await get_gw(test_cli, version)
|
||||
|
||||
|
|
@ -87,7 +118,8 @@ async def gw_start(test_cli, *, version: int = 6, etf=False):
|
|||
else:
|
||||
gw_url = f"{gw_url}?v={version}&encoding=json"
|
||||
|
||||
return await websockets.connect(gw_url)
|
||||
compress = f"&compress={compress}" if compress else ""
|
||||
return await websockets.connect(f"{gw_url}{compress}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -318,3 +350,18 @@ async def test_ready_bot(test_cli_bot):
|
|||
await extract_and_verify_ready(conn)
|
||||
finally:
|
||||
await _close(conn)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ready_bot_zlib_stream(test_cli_bot):
|
||||
conn = await gw_start(test_cli_bot.cli, compress="zlib-stream")
|
||||
await _json(conn, zlib_stream=True) # ignore hello
|
||||
await _json_send(
|
||||
conn,
|
||||
{"op": OP.IDENTIFY, "d": {"token": test_cli_bot.user["token"]}},
|
||||
)
|
||||
|
||||
try:
|
||||
await extract_and_verify_ready(conn, zlib_stream=True)
|
||||
finally:
|
||||
await _close(conn)
|
||||
|
|
|
|||
Loading…
Reference in New Issue