tests: add websockets->wsproto translation layer

see #139
This commit is contained in:
Luna 2022-01-28 00:00:45 -03:00
parent f3979221ad
commit 0faae9fafd
1 changed files with 111 additions and 4 deletions

View File

@ -19,10 +19,23 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import json import json
import zlib import zlib
import asyncio
import urllib.parse
import collections
from typing import Optional from typing import Optional
import pytest import pytest
import websockets import websockets
from logbook import Logger
from wsproto import WSConnection, ConnectionType
from wsproto.events import (
Request,
Message,
AcceptConnection,
CloseConnection,
Ping,
Pong,
)
from litecord.gateway.opcodes import OP from litecord.gateway.opcodes import OP
from litecord.gateway.websocket import decode_etf from litecord.gateway.websocket import decode_etf
@ -31,6 +44,92 @@ from litecord.gateway.websocket import decode_etf
ZLIB_SUFFIX = b"\x00\x00\xff\xff" ZLIB_SUFFIX = b"\x00\x00\xff\xff"
log = Logger("test_websocket")
RcvdWrapper = collections.namedtuple("RcvdWrapper", "code reason")
class AsyncWebsocket:
"""websockets-compatible websocket object"""
def __init__(self, url):
self.url = url
self.ws = WSConnection(ConnectionType.CLIENT)
self.reader, self.writer = None, None
async def send(self, data):
assert self.writer is not None
# wrap all strings in Message
if isinstance(data, str):
data = Message(data=data)
log.debug("sending {} event", type(data))
self.writer.write(self.ws.send(data))
await self.writer.drain()
async def recv(self, *, expect=Message, message_str: bool = True):
in_data = await self.reader.read(4096)
if not in_data:
log.info("connection closed (no data)")
self.ws.receive_data(None)
else:
log.debug("received {} bytes", len(in_data))
self.ws.receive_data(in_data)
# if we get a ping, reply with pong immediately
# and fetch the next event
event = next(self.ws.events())
if isinstance(event, Ping):
await self.send(Pong())
event = next(self.ws.events())
if isinstance(event, CloseConnection):
raise websockets.ConnectionClosed(
RcvdWrapper(event.code, event.reason), None
)
if expect is not None and not isinstance(event, expect):
raise AssertionError(
f"Expected {expect!r} websocket event, got {type(event)!r}"
)
# this keeps compatibility with code written for aaugustin/websockets
if expect is Message and message_str:
return event.data
return event
async def close(self, close_code: int, close_reason: str):
log.info("closing connection")
event = CloseConnection(code=close_code, reason=close_reason)
await self.send(event)
self.writer.close()
await self.writer.wait_closed()
self.ws.receive_data(None)
async def connect(self):
parsed = urllib.parse.urlparse(self.url)
if parsed.scheme == "wss":
port = 443
elif parsed.scheme == "ws":
port = 80
else:
raise AssertionError("Invalid url scheme")
host, *rest = parsed.netloc.split(":")
if rest:
port = rest[0]
log.info("connecting to {!r} {}", host, port)
self.reader, self.writer = await asyncio.open_connection(host, port)
path = parsed.path or "/"
target = f"{path}?{parsed.query}" if parsed.query else path
await self.send(Request(host=parsed.netloc, target=target))
await self.recv(expect=AcceptConnection)
async def _recv(conn, *, zlib_stream: bool): async def _recv(conn, *, zlib_stream: bool):
if zlib_stream: if zlib_stream:
try: try:
@ -43,11 +142,16 @@ async def _recv(conn, *, zlib_stream: bool):
zlib_buffer = bytearray() zlib_buffer = bytearray()
while True: while True:
# keep receiving frames until we find the zlib prefix inside # keep receiving frames until we find the zlib prefix inside
msg = await conn.recv() # we set message_str to false so that we get the entire event
zlib_buffer.extend(msg) # instead of only data
if len(msg) < 4 or msg[-4:] != ZLIB_SUFFIX: event = await conn.recv(message_str=False)
zlib_buffer.extend(event.data)
if not event.message_finished:
continue continue
if len(zlib_buffer) < 4 or zlib_buffer[-4:] != ZLIB_SUFFIX:
raise RuntimeError("Finished compressed message without ZLIB suffix")
# NOTE: the message is utf-8 encoded. # NOTE: the message is utf-8 encoded.
msg = conn._zlib_context.decompress(zlib_buffer) msg = conn._zlib_context.decompress(zlib_buffer)
return msg return msg
@ -119,7 +223,10 @@ async def gw_start(
gw_url = f"{gw_url}?v={version}&encoding=json" gw_url = f"{gw_url}?v={version}&encoding=json"
compress = f"&compress={compress}" if compress else "" compress = f"&compress={compress}" if compress else ""
return await websockets.connect(f"{gw_url}{compress}")
ws = AsyncWebsocket(f"{gw_url}{compress}")
await ws.connect()
return ws
@pytest.mark.asyncio @pytest.mark.asyncio