mirror of https://gitlab.com/litecord/litecord.git
tests: add websockets->wsproto translation layer
This commit is contained in:
parent
6ac705f838
commit
3b87a17477
|
|
@ -424,7 +424,7 @@ cffi = ["cffi (>=1.11)"]
|
|||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.9"
|
||||
content-hash = "56d99717b6f3a32560d33ee9efdb77c7cf59f2447f1ff626c8dabb1261dc30c7"
|
||||
content-hash = "135c208a1dd82e1f358357e6ce26802d134e57f6fb8966d634c8b200c9730d96"
|
||||
|
||||
[metadata.files]
|
||||
aiofiles = [
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ pillow = "^8.3.2"
|
|||
aiohttp = "^3.7.4"
|
||||
zstandard = "^0.15.2"
|
||||
winter = {git = "https://gitlab.com/elixire/winter"}
|
||||
wsproto = "^1.0.0"
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,10 +19,23 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|||
|
||||
import json
|
||||
import zlib
|
||||
import asyncio
|
||||
import urllib.parse
|
||||
import collections
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import websockets
|
||||
from logbook import Logger
|
||||
from wsproto import WSConnection, ConnectionType
|
||||
from wsproto.connection import ConnectionState
|
||||
from wsproto.events import (
|
||||
Request,
|
||||
Message,
|
||||
AcceptConnection,
|
||||
CloseConnection,
|
||||
Ping,
|
||||
)
|
||||
|
||||
from litecord.gateway.opcodes import OP
|
||||
from litecord.gateway.websocket import decode_etf
|
||||
|
|
@ -31,6 +44,109 @@ from litecord.gateway.websocket import decode_etf
|
|||
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
|
||||
|
||||
|
||||
log = Logger("test_websocket")
|
||||
|
||||
RcvdWrapper = collections.namedtuple("RcvdWrapper", "code reason")
|
||||
|
||||
|
||||
class AsyncWebsocket:
|
||||
"""websockets-compatible websocket object"""
|
||||
|
||||
def __init__(self, url):
|
||||
self.url = url
|
||||
self.ws = WSConnection(ConnectionType.CLIENT)
|
||||
self.reader, self.writer = None, None
|
||||
|
||||
async def send(self, data):
|
||||
assert self.writer is not None
|
||||
|
||||
# wrap all strings in Message
|
||||
if isinstance(data, str):
|
||||
data = Message(data=data)
|
||||
|
||||
log.debug("sending {} event", type(data))
|
||||
|
||||
self.writer.write(self.ws.send(data))
|
||||
await self.writer.drain()
|
||||
|
||||
async def recv(self, *, expect=Message, process_event: bool = True):
|
||||
|
||||
# this loop is only done so we reply to pings while also being
|
||||
# able to receive any other event in the middle.
|
||||
#
|
||||
# CloseConnection does not lead us to reading other events, so
|
||||
# that's why it's left out.
|
||||
|
||||
while True:
|
||||
# if there's already an unprocessed event we can try getting
|
||||
# it from wsproto first
|
||||
event = None
|
||||
for event in self.ws.events():
|
||||
break
|
||||
|
||||
if event is None:
|
||||
data = await self.reader.read(4096)
|
||||
assert data # We expect the WebSocket to be closed correctly
|
||||
self.ws.receive_data(data)
|
||||
continue
|
||||
|
||||
# if we get a ping, reply with pong immediately
|
||||
# and fetch the next event
|
||||
if isinstance(event, Ping):
|
||||
await self.send(event.response())
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
if isinstance(event, CloseConnection):
|
||||
assert self.ws.state is ConnectionState.REMOTE_CLOSING
|
||||
await self.send(event.response())
|
||||
if process_event:
|
||||
raise websockets.ConnectionClosed(
|
||||
RcvdWrapper(event.code, event.reason), None
|
||||
)
|
||||
|
||||
if expect is not None and not isinstance(event, expect):
|
||||
raise AssertionError(
|
||||
f"Expected {expect!r} websocket event, got {type(event)!r}"
|
||||
)
|
||||
|
||||
# this keeps compatibility with code written for aaugustin/websockets
|
||||
if expect is Message and process_event:
|
||||
return event.data
|
||||
|
||||
return event
|
||||
|
||||
async def close(self, close_code: int, close_reason: str):
|
||||
log.info("closing connection")
|
||||
event = CloseConnection(code=close_code, reason=close_reason)
|
||||
await self.send(event)
|
||||
self.writer.close()
|
||||
await self.writer.wait_closed()
|
||||
self.ws.receive_data(None)
|
||||
|
||||
async def connect(self):
|
||||
parsed = urllib.parse.urlparse(self.url)
|
||||
if parsed.scheme == "wss":
|
||||
port = 443
|
||||
elif parsed.scheme == "ws":
|
||||
port = 80
|
||||
else:
|
||||
raise AssertionError("Invalid url scheme")
|
||||
|
||||
host, *rest = parsed.netloc.split(":")
|
||||
if rest:
|
||||
port = rest[0]
|
||||
|
||||
log.info("connecting to {!r} {}", host, port)
|
||||
self.reader, self.writer = await asyncio.open_connection(host, port)
|
||||
|
||||
path = parsed.path or "/"
|
||||
target = f"{path}?{parsed.query}" if parsed.query else path
|
||||
await self.send(Request(host=parsed.netloc, target=target))
|
||||
await self.recv(expect=AcceptConnection)
|
||||
|
||||
|
||||
async def _recv(conn, *, zlib_stream: bool):
|
||||
if zlib_stream:
|
||||
try:
|
||||
|
|
@ -43,11 +159,16 @@ async def _recv(conn, *, zlib_stream: bool):
|
|||
zlib_buffer = bytearray()
|
||||
while True:
|
||||
# keep receiving frames until we find the zlib prefix inside
|
||||
msg = await conn.recv()
|
||||
zlib_buffer.extend(msg)
|
||||
if len(msg) < 4 or msg[-4:] != ZLIB_SUFFIX:
|
||||
# we set process_event to false so that we get the entire event
|
||||
# instead of only data
|
||||
event = await conn.recv(process_event=False)
|
||||
zlib_buffer.extend(event.data)
|
||||
if not event.message_finished:
|
||||
continue
|
||||
|
||||
if len(zlib_buffer) < 4 or zlib_buffer[-4:] != ZLIB_SUFFIX:
|
||||
raise RuntimeError("Finished compressed message without ZLIB suffix")
|
||||
|
||||
# NOTE: the message is utf-8 encoded.
|
||||
msg = conn._zlib_context.decompress(zlib_buffer)
|
||||
return msg
|
||||
|
|
@ -119,7 +240,10 @@ async def gw_start(
|
|||
gw_url = f"{gw_url}?v={version}&encoding=json"
|
||||
|
||||
compress = f"&compress={compress}" if compress else ""
|
||||
return await websockets.connect(f"{gw_url}{compress}")
|
||||
|
||||
ws = AsyncWebsocket(f"{gw_url}{compress}")
|
||||
await ws.connect()
|
||||
return ws
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -174,8 +298,6 @@ async def test_broken_identify(test_cli_user):
|
|||
raise AssertionError("Received a JSON message but expected close")
|
||||
except websockets.ConnectionClosed as exc:
|
||||
assert exc.code == 4002
|
||||
finally:
|
||||
await _close(conn)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Reference in New Issue