Merge branch 'refactor/wsproto-tests' into 'master'

tests: add websockets->wsproto translation layer

See merge request litecord/litecord!85
This commit is contained in:
luna 2022-01-29 02:54:31 +00:00
commit 001b15601e
3 changed files with 145 additions and 7 deletions

2
poetry.lock generated
View File

@ -424,7 +424,7 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "1.1"
python-versions = "^3.9"
content-hash = "56d99717b6f3a32560d33ee9efdb77c7cf59f2447f1ff626c8dabb1261dc30c7"
content-hash = "135c208a1dd82e1f358357e6ce26802d134e57f6fb8966d634c8b200c9730d96"
[metadata.files]
aiofiles = [

View File

@ -19,6 +19,7 @@ pillow = "^8.3.2"
aiohttp = "^3.7.4"
zstandard = "^0.15.2"
winter = {git = "https://gitlab.com/elixire/winter"}
wsproto = "^1.0.0"

View File

@ -19,10 +19,23 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import json
import zlib
import asyncio
import urllib.parse
import collections
from typing import Optional
import pytest
import websockets
from logbook import Logger
from wsproto import WSConnection, ConnectionType
from wsproto.connection import ConnectionState
from wsproto.events import (
Request,
Message,
AcceptConnection,
CloseConnection,
Ping,
)
from litecord.gateway.opcodes import OP
from litecord.gateway.websocket import decode_etf
@ -31,6 +44,124 @@ from litecord.gateway.websocket import decode_etf
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
log = Logger("test_websocket")
RcvdWrapper = collections.namedtuple("RcvdWrapper", "code reason")
class AsyncWebsocket:
"""websockets-compatible websocket object"""
def __init__(self, url):
self.url = url
self.ws = WSConnection(ConnectionType.CLIENT)
self.reader, self.writer = None, None
self.reader_task = None
self._events = asyncio.Queue()
async def send(self, data):
assert self.writer is not None
# wrap all strings in Message
if isinstance(data, str):
data = Message(data=data)
log.debug("sending {} event", type(data))
self.writer.write(self.ws.send(data))
await self.writer.drain()
async def _reader_loop_task(self):
# continuously read messages from the socket
# and fill up the _events queue with them
while True:
log.info("reading data")
in_data = await self.reader.read(4096)
if not in_data:
log.info("connection closed (no data)")
self.ws.receive_data(None)
else:
log.debug("received {} bytes", len(in_data))
self.ws.receive_data(in_data)
for event in self.ws.events():
log.debug("queued ws event {}", event)
await self._events.put(event)
# since we closed, we don't have to continue reading
if not in_data:
return
async def recv(self, *, expect=Message, process_event: bool = True):
# this loop is only done so we reply to pings while also being
# able to receive any other event in the middle.
#
# CloseConnection does not lead us to reading other events, so
# that's why it's left out.
while True:
# if we get a ping, reply with pong immediately
# and fetch the next event
event = await self._events.get()
log.debug("processing {}", event)
if isinstance(event, Ping):
await self.send(event.response())
continue
break
if isinstance(event, CloseConnection):
assert self.ws.state is ConnectionState.REMOTE_CLOSING
await self.send(event.response())
if process_event:
raise websockets.ConnectionClosed(
RcvdWrapper(event.code, event.reason), None
)
if expect is not None and not isinstance(event, expect):
raise AssertionError(
f"Expected {expect!r} websocket event, got {type(event)!r}"
)
# this keeps compatibility with code written for aaugustin/websockets
if expect is Message and process_event:
return event.data
return event
async def close(self, close_code: int, close_reason: str):
log.info("closing connection")
event = CloseConnection(code=close_code, reason=close_reason)
await self.send(event)
self.writer.close()
await self.writer.wait_closed()
self.ws.receive_data(None)
async def connect(self):
parsed = urllib.parse.urlparse(self.url)
if parsed.scheme == "wss":
port = 443
elif parsed.scheme == "ws":
port = 80
else:
raise AssertionError("Invalid url scheme")
host, *rest = parsed.netloc.split(":")
if rest:
port = rest[0]
log.info("connecting to {!r} {}", host, port)
self.reader, self.writer = await asyncio.open_connection(host, port)
self.reader_task = asyncio.create_task(self._reader_loop_task())
path = parsed.path or "/"
target = f"{path}?{parsed.query}" if parsed.query else path
await self.send(Request(host=parsed.netloc, target=target))
await self.recv(expect=AcceptConnection)
async def _recv(conn, *, zlib_stream: bool):
if zlib_stream:
try:
@ -43,11 +174,16 @@ async def _recv(conn, *, zlib_stream: bool):
zlib_buffer = bytearray()
while True:
# keep receiving frames until we find the zlib prefix inside
msg = await conn.recv()
zlib_buffer.extend(msg)
if len(msg) < 4 or msg[-4:] != ZLIB_SUFFIX:
# we set process_event to false so that we get the entire event
# instead of only data
event = await conn.recv(process_event=False)
zlib_buffer.extend(event.data)
if not event.message_finished:
continue
if len(zlib_buffer) < 4 or zlib_buffer[-4:] != ZLIB_SUFFIX:
raise RuntimeError("Finished compressed message without ZLIB suffix")
# NOTE: the message is utf-8 encoded.
msg = conn._zlib_context.decompress(zlib_buffer)
return msg
@ -119,7 +255,10 @@ async def gw_start(
gw_url = f"{gw_url}?v={version}&encoding=json"
compress = f"&compress={compress}" if compress else ""
return await websockets.connect(f"{gw_url}{compress}")
ws = AsyncWebsocket(f"{gw_url}{compress}")
await ws.connect()
return ws
@pytest.mark.asyncio
@ -174,8 +313,6 @@ async def test_broken_identify(test_cli_user):
raise AssertionError("Received a JSON message but expected close")
except websockets.ConnectionClosed as exc:
assert exc.code == 4002
finally:
await _close(conn)
@pytest.mark.asyncio