diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 0d12787..f18c7fd 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -19,10 +19,23 @@ along with this program. If not, see . import json import zlib +import asyncio +import urllib.parse +import collections from typing import Optional import pytest import websockets +from logbook import Logger +from wsproto import WSConnection, ConnectionType +from wsproto.events import ( + Request, + Message, + AcceptConnection, + CloseConnection, + Ping, + Pong, +) from litecord.gateway.opcodes import OP from litecord.gateway.websocket import decode_etf @@ -31,6 +44,92 @@ from litecord.gateway.websocket import decode_etf ZLIB_SUFFIX = b"\x00\x00\xff\xff" +log = Logger("test_websocket") + +RcvdWrapper = collections.namedtuple("RcvdWrapper", "code reason") + + +class AsyncWebsocket: + """websockets-compatible websocket object""" + + def __init__(self, url): + self.url = url + self.ws = WSConnection(ConnectionType.CLIENT) + self.reader, self.writer = None, None + + async def send(self, data): + assert self.writer is not None + + # wrap all strings in Message + if isinstance(data, str): + data = Message(data=data) + + log.debug("sending {} event", type(data)) + + self.writer.write(self.ws.send(data)) + await self.writer.drain() + + async def recv(self, *, expect=Message, message_str: bool = True): + in_data = await self.reader.read(4096) + if not in_data: + log.info("connection closed (no data)") + self.ws.receive_data(None) + else: + log.debug("received {} bytes", len(in_data)) + self.ws.receive_data(in_data) + + # if we get a ping, reply with pong immediately + # and fetch the next event + event = next(self.ws.events()) + if isinstance(event, Ping): + await self.send(Pong()) + event = next(self.ws.events()) + + if isinstance(event, CloseConnection): + raise websockets.ConnectionClosed( + RcvdWrapper(event.code, event.reason), None + ) + + if expect is not None and not isinstance(event, expect): + raise AssertionError( + f"Expected {expect!r} websocket event, got {type(event)!r}" + ) + + # this keeps compatibility with code written for aaugustin/websockets + if expect is Message and message_str: + return event.data + + return event + + async def close(self, close_code: int, close_reason: str): + log.info("closing connection") + event = CloseConnection(code=close_code, reason=close_reason) + await self.send(event) + self.writer.close() + await self.writer.wait_closed() + self.ws.receive_data(None) + + async def connect(self): + parsed = urllib.parse.urlparse(self.url) + if parsed.scheme == "wss": + port = 443 + elif parsed.scheme == "ws": + port = 80 + else: + raise AssertionError("Invalid url scheme") + + host, *rest = parsed.netloc.split(":") + if rest: + port = rest[0] + + log.info("connecting to {!r} {}", host, port) + self.reader, self.writer = await asyncio.open_connection(host, port) + path = parsed.path or "/" + target = f"{path}?{parsed.query}" if parsed.query else path + await self.send(Request(host=parsed.netloc, target=target)) + await self.recv(expect=AcceptConnection) + + async def _recv(conn, *, zlib_stream: bool): if zlib_stream: try: @@ -43,11 +142,16 @@ async def _recv(conn, *, zlib_stream: bool): zlib_buffer = bytearray() while True: # keep receiving frames until we find the zlib prefix inside - msg = await conn.recv() - zlib_buffer.extend(msg) - if len(msg) < 4 or msg[-4:] != ZLIB_SUFFIX: + # we set message_str to false so that we get the entire event + # instead of only data + event = await conn.recv(message_str=False) + zlib_buffer.extend(event.data) + if not event.message_finished: continue + if len(zlib_buffer) < 4 or zlib_buffer[-4:] != ZLIB_SUFFIX: + raise RuntimeError("Finished compressed message without ZLIB suffix") + # NOTE: the message is utf-8 encoded. msg = conn._zlib_context.decompress(zlib_buffer) return msg @@ -119,7 +223,10 @@ async def gw_start( gw_url = f"{gw_url}?v={version}&encoding=json" compress = f"&compress={compress}" if compress else "" - return await websockets.connect(f"{gw_url}{compress}") + + ws = AsyncWebsocket(f"{gw_url}{compress}") + await ws.connect() + return ws @pytest.mark.asyncio