tests: add websockets->wsproto translation layer

see #139
This commit is contained in:
Luna 2022-01-28 00:00:45 -03:00
parent f3979221ad
commit 0faae9fafd
1 changed files with 111 additions and 4 deletions

View File

@ -19,10 +19,23 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import json
import zlib
import asyncio
import urllib.parse
import collections
from typing import Optional
import pytest
import websockets
from logbook import Logger
from wsproto import WSConnection, ConnectionType
from wsproto.events import (
Request,
Message,
AcceptConnection,
CloseConnection,
Ping,
Pong,
)
from litecord.gateway.opcodes import OP
from litecord.gateway.websocket import decode_etf
@ -31,6 +44,92 @@ from litecord.gateway.websocket import decode_etf
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
log = Logger("test_websocket")
RcvdWrapper = collections.namedtuple("RcvdWrapper", "code reason")
class AsyncWebsocket:
"""websockets-compatible websocket object"""
def __init__(self, url):
self.url = url
self.ws = WSConnection(ConnectionType.CLIENT)
self.reader, self.writer = None, None
async def send(self, data):
assert self.writer is not None
# wrap all strings in Message
if isinstance(data, str):
data = Message(data=data)
log.debug("sending {} event", type(data))
self.writer.write(self.ws.send(data))
await self.writer.drain()
async def recv(self, *, expect=Message, message_str: bool = True):
in_data = await self.reader.read(4096)
if not in_data:
log.info("connection closed (no data)")
self.ws.receive_data(None)
else:
log.debug("received {} bytes", len(in_data))
self.ws.receive_data(in_data)
# if we get a ping, reply with pong immediately
# and fetch the next event
event = next(self.ws.events())
if isinstance(event, Ping):
await self.send(Pong())
event = next(self.ws.events())
if isinstance(event, CloseConnection):
raise websockets.ConnectionClosed(
RcvdWrapper(event.code, event.reason), None
)
if expect is not None and not isinstance(event, expect):
raise AssertionError(
f"Expected {expect!r} websocket event, got {type(event)!r}"
)
# this keeps compatibility with code written for aaugustin/websockets
if expect is Message and message_str:
return event.data
return event
async def close(self, close_code: int, close_reason: str):
log.info("closing connection")
event = CloseConnection(code=close_code, reason=close_reason)
await self.send(event)
self.writer.close()
await self.writer.wait_closed()
self.ws.receive_data(None)
async def connect(self):
parsed = urllib.parse.urlparse(self.url)
if parsed.scheme == "wss":
port = 443
elif parsed.scheme == "ws":
port = 80
else:
raise AssertionError("Invalid url scheme")
host, *rest = parsed.netloc.split(":")
if rest:
port = rest[0]
log.info("connecting to {!r} {}", host, port)
self.reader, self.writer = await asyncio.open_connection(host, port)
path = parsed.path or "/"
target = f"{path}?{parsed.query}" if parsed.query else path
await self.send(Request(host=parsed.netloc, target=target))
await self.recv(expect=AcceptConnection)
async def _recv(conn, *, zlib_stream: bool):
if zlib_stream:
try:
@ -43,11 +142,16 @@ async def _recv(conn, *, zlib_stream: bool):
zlib_buffer = bytearray()
while True:
# keep receiving frames until we find the zlib prefix inside
msg = await conn.recv()
zlib_buffer.extend(msg)
if len(msg) < 4 or msg[-4:] != ZLIB_SUFFIX:
# we set message_str to false so that we get the entire event
# instead of only data
event = await conn.recv(message_str=False)
zlib_buffer.extend(event.data)
if not event.message_finished:
continue
if len(zlib_buffer) < 4 or zlib_buffer[-4:] != ZLIB_SUFFIX:
raise RuntimeError("Finished compressed message without ZLIB suffix")
# NOTE: the message is utf-8 encoded.
msg = conn._zlib_context.decompress(zlib_buffer)
return msg
@ -119,7 +223,10 @@ async def gw_start(
gw_url = f"{gw_url}?v={version}&encoding=json"
compress = f"&compress={compress}" if compress else ""
return await websockets.connect(f"{gw_url}{compress}")
ws = AsyncWebsocket(f"{gw_url}{compress}")
await ws.connect()
return ws
@pytest.mark.asyncio