diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 4c0c7f8..27b6f6e 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -68,7 +68,7 @@ class AsyncWebsocket: self.writer.write(self.ws.send(data)) await self.writer.drain() - async def recv(self, *, expect=Message, message_str: bool = True): + async def recv(self, *, expect=Message, process_event: bool = True): in_data = await self.reader.read(4096) if not in_data: log.info("connection closed (no data)") @@ -85,9 +85,11 @@ class AsyncWebsocket: event = next(self.ws.events()) if isinstance(event, CloseConnection): - raise websockets.ConnectionClosed( - RcvdWrapper(event.code, event.reason), None - ) + await self.send(event.response()) + if process_event: + raise websockets.ConnectionClosed( + RcvdWrapper(event.code, event.reason), None + ) if expect is not None and not isinstance(event, expect): raise AssertionError( @@ -95,7 +97,7 @@ class AsyncWebsocket: ) # this keeps compatibility with code written for aaugustin/websockets - if expect is Message and message_str: + if expect is Message and process_event: return event.data return event @@ -141,9 +143,9 @@ async def _recv(conn, *, zlib_stream: bool): zlib_buffer = bytearray() while True: # keep receiving frames until we find the zlib prefix inside - # we set message_str to false so that we get the entire event + # we set process_event to false so that we get the entire event # instead of only data - event = await conn.recv(message_str=False) + event = await conn.recv(process_event=False) zlib_buffer.extend(event.data) if not event.message_finished: continue