From 3b87a17477858b4fa5f6d3963973466fd3bb7d35 Mon Sep 17 00:00:00 2001 From: luna Date: Sat, 29 Jan 2022 23:38:21 +0000 Subject: [PATCH] tests: add websockets->wsproto translation layer --- poetry.lock | 2 +- pyproject.toml | 1 + tests/test_websocket.py | 134 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 130 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index 3c2ecb0..00b78f1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -424,7 +424,7 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "56d99717b6f3a32560d33ee9efdb77c7cf59f2447f1ff626c8dabb1261dc30c7" +content-hash = "135c208a1dd82e1f358357e6ce26802d134e57f6fb8966d634c8b200c9730d96" [metadata.files] aiofiles = [ diff --git a/pyproject.toml b/pyproject.toml index 75297e3..31ce245 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ pillow = "^8.3.2" aiohttp = "^3.7.4" zstandard = "^0.15.2" winter = {git = "https://gitlab.com/elixire/winter"} +wsproto = "^1.0.0" diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 0d12787..ee8d3fe 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -19,10 +19,23 @@ along with this program. If not, see . import json import zlib +import asyncio +import urllib.parse +import collections from typing import Optional import pytest import websockets +from logbook import Logger +from wsproto import WSConnection, ConnectionType +from wsproto.connection import ConnectionState +from wsproto.events import ( + Request, + Message, + AcceptConnection, + CloseConnection, + Ping, +) from litecord.gateway.opcodes import OP from litecord.gateway.websocket import decode_etf @@ -31,6 +44,109 @@ from litecord.gateway.websocket import decode_etf ZLIB_SUFFIX = b"\x00\x00\xff\xff" +log = Logger("test_websocket") + +RcvdWrapper = collections.namedtuple("RcvdWrapper", "code reason") + + +class AsyncWebsocket: + """websockets-compatible websocket object""" + + def __init__(self, url): + self.url = url + self.ws = WSConnection(ConnectionType.CLIENT) + self.reader, self.writer = None, None + + async def send(self, data): + assert self.writer is not None + + # wrap all strings in Message + if isinstance(data, str): + data = Message(data=data) + + log.debug("sending {} event", type(data)) + + self.writer.write(self.ws.send(data)) + await self.writer.drain() + + async def recv(self, *, expect=Message, process_event: bool = True): + + # this loop is only done so we reply to pings while also being + # able to receive any other event in the middle. + # + # CloseConnection does not lead us to reading other events, so + # that's why it's left out. + + while True: + # if there's already an unprocessed event we can try getting + # it from wsproto first + event = None + for event in self.ws.events(): + break + + if event is None: + data = await self.reader.read(4096) + assert data # We expect the WebSocket to be closed correctly + self.ws.receive_data(data) + continue + + # if we get a ping, reply with pong immediately + # and fetch the next event + if isinstance(event, Ping): + await self.send(event.response()) + continue + + break + + if isinstance(event, CloseConnection): + assert self.ws.state is ConnectionState.REMOTE_CLOSING + await self.send(event.response()) + if process_event: + raise websockets.ConnectionClosed( + RcvdWrapper(event.code, event.reason), None + ) + + if expect is not None and not isinstance(event, expect): + raise AssertionError( + f"Expected {expect!r} websocket event, got {type(event)!r}" + ) + + # this keeps compatibility with code written for aaugustin/websockets + if expect is Message and process_event: + return event.data + + return event + + async def close(self, close_code: int, close_reason: str): + log.info("closing connection") + event = CloseConnection(code=close_code, reason=close_reason) + await self.send(event) + self.writer.close() + await self.writer.wait_closed() + self.ws.receive_data(None) + + async def connect(self): + parsed = urllib.parse.urlparse(self.url) + if parsed.scheme == "wss": + port = 443 + elif parsed.scheme == "ws": + port = 80 + else: + raise AssertionError("Invalid url scheme") + + host, *rest = parsed.netloc.split(":") + if rest: + port = rest[0] + + log.info("connecting to {!r} {}", host, port) + self.reader, self.writer = await asyncio.open_connection(host, port) + + path = parsed.path or "/" + target = f"{path}?{parsed.query}" if parsed.query else path + await self.send(Request(host=parsed.netloc, target=target)) + await self.recv(expect=AcceptConnection) + + async def _recv(conn, *, zlib_stream: bool): if zlib_stream: try: @@ -43,11 +159,16 @@ async def _recv(conn, *, zlib_stream: bool): zlib_buffer = bytearray() while True: # keep receiving frames until we find the zlib prefix inside - msg = await conn.recv() - zlib_buffer.extend(msg) - if len(msg) < 4 or msg[-4:] != ZLIB_SUFFIX: + # we set process_event to false so that we get the entire event + # instead of only data + event = await conn.recv(process_event=False) + zlib_buffer.extend(event.data) + if not event.message_finished: continue + if len(zlib_buffer) < 4 or zlib_buffer[-4:] != ZLIB_SUFFIX: + raise RuntimeError("Finished compressed message without ZLIB suffix") + # NOTE: the message is utf-8 encoded. msg = conn._zlib_context.decompress(zlib_buffer) return msg @@ -119,7 +240,10 @@ async def gw_start( gw_url = f"{gw_url}?v={version}&encoding=json" compress = f"&compress={compress}" if compress else "" - return await websockets.connect(f"{gw_url}{compress}") + + ws = AsyncWebsocket(f"{gw_url}{compress}") + await ws.connect() + return ws @pytest.mark.asyncio @@ -174,8 +298,6 @@ async def test_broken_identify(test_cli_user): raise AssertionError("Received a JSON message but expected close") except websockets.ConnectionClosed as exc: assert exc.code == 4002 - finally: - await _close(conn) @pytest.mark.asyncio