From f3979221addceb25fd5980ddf2978eb137812423 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jan 2022 00:00:34 -0300 Subject: [PATCH 01/10] add wsproto dep --- poetry.lock | 2 +- pyproject.toml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 3c2ecb0..00b78f1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -424,7 +424,7 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "56d99717b6f3a32560d33ee9efdb77c7cf59f2447f1ff626c8dabb1261dc30c7" +content-hash = "135c208a1dd82e1f358357e6ce26802d134e57f6fb8966d634c8b200c9730d96" [metadata.files] aiofiles = [ diff --git a/pyproject.toml b/pyproject.toml index 75297e3..31ce245 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ pillow = "^8.3.2" aiohttp = "^3.7.4" zstandard = "^0.15.2" winter = {git = "https://gitlab.com/elixire/winter"} +wsproto = "^1.0.0" From 0faae9fafda46a3ee3a54f1caa7a2ce637a9cda0 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jan 2022 00:00:45 -0300 Subject: [PATCH 02/10] tests: add websockets->wsproto translation layer see #139 --- tests/test_websocket.py | 115 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 111 insertions(+), 4 deletions(-) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 0d12787..f18c7fd 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -19,10 +19,23 @@ along with this program. If not, see . import json import zlib +import asyncio +import urllib.parse +import collections from typing import Optional import pytest import websockets +from logbook import Logger +from wsproto import WSConnection, ConnectionType +from wsproto.events import ( + Request, + Message, + AcceptConnection, + CloseConnection, + Ping, + Pong, +) from litecord.gateway.opcodes import OP from litecord.gateway.websocket import decode_etf @@ -31,6 +44,92 @@ from litecord.gateway.websocket import decode_etf ZLIB_SUFFIX = b"\x00\x00\xff\xff" +log = Logger("test_websocket") + +RcvdWrapper = collections.namedtuple("RcvdWrapper", "code reason") + + +class AsyncWebsocket: + """websockets-compatible websocket object""" + + def __init__(self, url): + self.url = url + self.ws = WSConnection(ConnectionType.CLIENT) + self.reader, self.writer = None, None + + async def send(self, data): + assert self.writer is not None + + # wrap all strings in Message + if isinstance(data, str): + data = Message(data=data) + + log.debug("sending {} event", type(data)) + + self.writer.write(self.ws.send(data)) + await self.writer.drain() + + async def recv(self, *, expect=Message, message_str: bool = True): + in_data = await self.reader.read(4096) + if not in_data: + log.info("connection closed (no data)") + self.ws.receive_data(None) + else: + log.debug("received {} bytes", len(in_data)) + self.ws.receive_data(in_data) + + # if we get a ping, reply with pong immediately + # and fetch the next event + event = next(self.ws.events()) + if isinstance(event, Ping): + await self.send(Pong()) + event = next(self.ws.events()) + + if isinstance(event, CloseConnection): + raise websockets.ConnectionClosed( + RcvdWrapper(event.code, event.reason), None + ) + + if expect is not None and not isinstance(event, expect): + raise AssertionError( + f"Expected {expect!r} websocket event, got {type(event)!r}" + ) + + # this keeps compatibility with code written for aaugustin/websockets + if expect is Message and message_str: + return event.data + + return event + + async def close(self, close_code: int, close_reason: str): + log.info("closing connection") + event = CloseConnection(code=close_code, reason=close_reason) + await self.send(event) + self.writer.close() + await self.writer.wait_closed() + self.ws.receive_data(None) + + async def connect(self): + parsed = urllib.parse.urlparse(self.url) + if parsed.scheme == "wss": + port = 443 + elif parsed.scheme == "ws": + port = 80 + else: + raise AssertionError("Invalid url scheme") + + host, *rest = parsed.netloc.split(":") + if rest: + port = rest[0] + + log.info("connecting to {!r} {}", host, port) + self.reader, self.writer = await asyncio.open_connection(host, port) + path = parsed.path or "/" + target = f"{path}?{parsed.query}" if parsed.query else path + await self.send(Request(host=parsed.netloc, target=target)) + await self.recv(expect=AcceptConnection) + + async def _recv(conn, *, zlib_stream: bool): if zlib_stream: try: @@ -43,11 +142,16 @@ async def _recv(conn, *, zlib_stream: bool): zlib_buffer = bytearray() while True: # keep receiving frames until we find the zlib prefix inside - msg = await conn.recv() - zlib_buffer.extend(msg) - if len(msg) < 4 or msg[-4:] != ZLIB_SUFFIX: + # we set message_str to false so that we get the entire event + # instead of only data + event = await conn.recv(message_str=False) + zlib_buffer.extend(event.data) + if not event.message_finished: continue + if len(zlib_buffer) < 4 or zlib_buffer[-4:] != ZLIB_SUFFIX: + raise RuntimeError("Finished compressed message without ZLIB suffix") + # NOTE: the message is utf-8 encoded. msg = conn._zlib_context.decompress(zlib_buffer) return msg @@ -119,7 +223,10 @@ async def gw_start( gw_url = f"{gw_url}?v={version}&encoding=json" compress = f"&compress={compress}" if compress else "" - return await websockets.connect(f"{gw_url}{compress}") + + ws = AsyncWebsocket(f"{gw_url}{compress}") + await ws.connect() + return ws @pytest.mark.asyncio From 9b98257741895c3f6b1a0a5ce7cf4812c33b92fa Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jan 2022 00:24:24 -0300 Subject: [PATCH 03/10] tests: reply with correct Pong event --- tests/test_websocket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index f18c7fd..e66ce06 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -82,7 +82,7 @@ class AsyncWebsocket: # and fetch the next event event = next(self.ws.events()) if isinstance(event, Ping): - await self.send(Pong()) + await self.send(event.response()) event = next(self.ws.events()) if isinstance(event, CloseConnection): From 94e1f16be5c4b1db87059840f8400be0f3d263fd Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jan 2022 00:24:58 -0300 Subject: [PATCH 04/10] remove unused import --- tests/test_websocket.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index e66ce06..4c0c7f8 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -34,7 +34,6 @@ from wsproto.events import ( AcceptConnection, CloseConnection, Ping, - Pong, ) from litecord.gateway.opcodes import OP From 28299607cd1b88e754da3b0b0d043745a6196759 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jan 2022 00:28:25 -0300 Subject: [PATCH 05/10] tests: send correct protocol reply on close event --- tests/test_websocket.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 4c0c7f8..27b6f6e 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -68,7 +68,7 @@ class AsyncWebsocket: self.writer.write(self.ws.send(data)) await self.writer.drain() - async def recv(self, *, expect=Message, message_str: bool = True): + async def recv(self, *, expect=Message, process_event: bool = True): in_data = await self.reader.read(4096) if not in_data: log.info("connection closed (no data)") @@ -85,9 +85,11 @@ class AsyncWebsocket: event = next(self.ws.events()) if isinstance(event, CloseConnection): - raise websockets.ConnectionClosed( - RcvdWrapper(event.code, event.reason), None - ) + await self.send(event.response()) + if process_event: + raise websockets.ConnectionClosed( + RcvdWrapper(event.code, event.reason), None + ) if expect is not None and not isinstance(event, expect): raise AssertionError( @@ -95,7 +97,7 @@ class AsyncWebsocket: ) # this keeps compatibility with code written for aaugustin/websockets - if expect is Message and message_str: + if expect is Message and process_event: return event.data return event @@ -141,9 +143,9 @@ async def _recv(conn, *, zlib_stream: bool): zlib_buffer = bytearray() while True: # keep receiving frames until we find the zlib prefix inside - # we set message_str to false so that we get the entire event + # we set process_event to false so that we get the entire event # instead of only data - event = await conn.recv(message_str=False) + event = await conn.recv(process_event=False) zlib_buffer.extend(event.data) if not event.message_finished: continue From 38b560205fbf6fb050d6860611531faeb9d08ef5 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jan 2022 23:25:39 -0300 Subject: [PATCH 06/10] tests: use an internal queue for wsproto events --- tests/test_websocket.py | 77 +++++++++++++++++++++++++++++++++-------- 1 file changed, 63 insertions(+), 14 deletions(-) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 27b6f6e..6cb0d58 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -55,6 +55,9 @@ class AsyncWebsocket: self.url = url self.ws = WSConnection(ConnectionType.CLIENT) self.reader, self.writer = None, None + self.reader_task = None + self._waiting_for_message_event = None + self._events = asyncio.Queue() async def send(self, data): assert self.writer is not None @@ -68,21 +71,65 @@ class AsyncWebsocket: self.writer.write(self.ws.send(data)) await self.writer.drain() - async def recv(self, *, expect=Message, process_event: bool = True): - in_data = await self.reader.read(4096) - if not in_data: - log.info("connection closed (no data)") - self.ws.receive_data(None) - else: - log.debug("received {} bytes", len(in_data)) - self.ws.receive_data(in_data) + async def _reader_loop_task(self): - # if we get a ping, reply with pong immediately - # and fetch the next event - event = next(self.ws.events()) - if isinstance(event, Ping): - await self.send(event.response()) - event = next(self.ws.events()) + # continuously read messages from the socket + # and fill up the _events queue with them + # + # if a recv() coroutine has been waiting for an event + # (via _waiting_for_message_event), then set that event so that + # we immediately process it + while True: + log.info("reading data") + in_data = await self.reader.read(4096) + if not in_data: + log.info("connection closed (no data)") + self.ws.receive_data(None) + else: + log.debug("received {} bytes", len(in_data)) + self.ws.receive_data(in_data) + + for event in self.ws.events(): + log.debug("queued ws event {}", event) + await self._events.put(event) + + if not self._events.empty() and self._waiting_for_message_event: + self._waiting_for_message_event.set() + + # since we closed, we don't have to continue reading + if not in_data: + return + + async def recv(self, *, expect=Message, process_event: bool = True): + + # if queue is empty, wait until it's filled up + if self._events.empty(): + self._waiting_for_message_event = asyncio.Event() + try: + await asyncio.wait( + [self._waiting_for_message_event.wait(), self.reader_task], + return_when=asyncio.FIRST_COMPLETED, + ) + finally: + self._waiting_for_message_event = None + + # this loop is only done so we reply to pings while also being + # able to receive any other event in the middle. + # + # CloseConnection does not lead us to reading other events, so + # that's why it's left out. + + while True: + # if we get a ping, reply with pong immediately + # and fetch the next event + event = self._events.get_nowait() + log.debug("processing {}", event) + + if isinstance(event, Ping): + await self.send(event.response()) + continue + + break if isinstance(event, CloseConnection): await self.send(event.response()) @@ -125,6 +172,8 @@ class AsyncWebsocket: log.info("connecting to {!r} {}", host, port) self.reader, self.writer = await asyncio.open_connection(host, port) + self.reader_task = asyncio.create_task(self._reader_loop_task()) + path = parsed.path or "/" target = f"{path}?{parsed.query}" if parsed.query else path await self.send(Request(host=parsed.netloc, target=target)) From 049523b03f5cdf424a40119b636f6069d4a638f3 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jan 2022 23:38:12 -0300 Subject: [PATCH 07/10] tests: assert we are in a good state on autoreply --- tests/test_websocket.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 6cb0d58..f436c43 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -28,6 +28,7 @@ import pytest import websockets from logbook import Logger from wsproto import WSConnection, ConnectionType +from wsproto.connection import ConnectionState from wsproto.events import ( Request, Message, @@ -132,6 +133,7 @@ class AsyncWebsocket: break if isinstance(event, CloseConnection): + assert self.ws.state is ConnectionState.REMOTE_CLOSING await self.send(event.response()) if process_event: raise websockets.ConnectionClosed( From 4a70d9580d8ccfae9e95db5ea0fbcf762ea655bf Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jan 2022 23:38:33 -0300 Subject: [PATCH 08/10] tests: don't double-close on test_broken_identify --- tests/test_websocket.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index f436c43..6344bd4 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -333,8 +333,6 @@ async def test_broken_identify(test_cli_user): raise AssertionError("Received a JSON message but expected close") except websockets.ConnectionClosed as exc: assert exc.code == 4002 - finally: - await _close(conn) @pytest.mark.asyncio From 2180fbca0257770d13684fc1bbad7178883d2043 Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 28 Jan 2022 23:52:46 -0300 Subject: [PATCH 09/10] tests: remove unecessary asyncio event queue's get() already blocks if no items are in the queue. --- tests/test_websocket.py | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 6344bd4..284123c 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -57,7 +57,6 @@ class AsyncWebsocket: self.ws = WSConnection(ConnectionType.CLIENT) self.reader, self.writer = None, None self.reader_task = None - self._waiting_for_message_event = None self._events = asyncio.Queue() async def send(self, data): @@ -73,13 +72,8 @@ class AsyncWebsocket: await self.writer.drain() async def _reader_loop_task(self): - # continuously read messages from the socket # and fill up the _events queue with them - # - # if a recv() coroutine has been waiting for an event - # (via _waiting_for_message_event), then set that event so that - # we immediately process it while True: log.info("reading data") in_data = await self.reader.read(4096) @@ -94,26 +88,12 @@ class AsyncWebsocket: log.debug("queued ws event {}", event) await self._events.put(event) - if not self._events.empty() and self._waiting_for_message_event: - self._waiting_for_message_event.set() - # since we closed, we don't have to continue reading if not in_data: return async def recv(self, *, expect=Message, process_event: bool = True): - # if queue is empty, wait until it's filled up - if self._events.empty(): - self._waiting_for_message_event = asyncio.Event() - try: - await asyncio.wait( - [self._waiting_for_message_event.wait(), self.reader_task], - return_when=asyncio.FIRST_COMPLETED, - ) - finally: - self._waiting_for_message_event = None - # this loop is only done so we reply to pings while also being # able to receive any other event in the middle. # @@ -123,7 +103,7 @@ class AsyncWebsocket: while True: # if we get a ping, reply with pong immediately # and fetch the next event - event = self._events.get_nowait() + event = await self._events.get() log.debug("processing {}", event) if isinstance(event, Ping): From df0a77002e5edcefc7d9b95e33da62bab3bd60b1 Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 29 Jan 2022 18:26:05 -0300 Subject: [PATCH 10/10] tests: remove unecessary reader task --- tests/test_websocket.py | 39 ++++++++++++--------------------------- 1 file changed, 12 insertions(+), 27 deletions(-) diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 284123c..ee8d3fe 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -56,8 +56,6 @@ class AsyncWebsocket: self.url = url self.ws = WSConnection(ConnectionType.CLIENT) self.reader, self.writer = None, None - self.reader_task = None - self._events = asyncio.Queue() async def send(self, data): assert self.writer is not None @@ -71,27 +69,6 @@ class AsyncWebsocket: self.writer.write(self.ws.send(data)) await self.writer.drain() - async def _reader_loop_task(self): - # continuously read messages from the socket - # and fill up the _events queue with them - while True: - log.info("reading data") - in_data = await self.reader.read(4096) - if not in_data: - log.info("connection closed (no data)") - self.ws.receive_data(None) - else: - log.debug("received {} bytes", len(in_data)) - self.ws.receive_data(in_data) - - for event in self.ws.events(): - log.debug("queued ws event {}", event) - await self._events.put(event) - - # since we closed, we don't have to continue reading - if not in_data: - return - async def recv(self, *, expect=Message, process_event: bool = True): # this loop is only done so we reply to pings while also being @@ -101,11 +78,20 @@ class AsyncWebsocket: # that's why it's left out. while True: + # if there's already an unprocessed event we can try getting + # it from wsproto first + event = None + for event in self.ws.events(): + break + + if event is None: + data = await self.reader.read(4096) + assert data # We expect the WebSocket to be closed correctly + self.ws.receive_data(data) + continue + # if we get a ping, reply with pong immediately # and fetch the next event - event = await self._events.get() - log.debug("processing {}", event) - if isinstance(event, Ping): await self.send(event.response()) continue @@ -154,7 +140,6 @@ class AsyncWebsocket: log.info("connecting to {!r} {}", host, port) self.reader, self.writer = await asyncio.open_connection(host, port) - self.reader_task = asyncio.create_task(self._reader_loop_task()) path = parsed.path or "/" target = f"{path}?{parsed.query}" if parsed.query else path