diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 284123c..ee8d3fe 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -56,8 +56,6 @@ class AsyncWebsocket: self.url = url self.ws = WSConnection(ConnectionType.CLIENT) self.reader, self.writer = None, None - self.reader_task = None - self._events = asyncio.Queue() async def send(self, data): assert self.writer is not None @@ -71,27 +69,6 @@ class AsyncWebsocket: self.writer.write(self.ws.send(data)) await self.writer.drain() - async def _reader_loop_task(self): - # continuously read messages from the socket - # and fill up the _events queue with them - while True: - log.info("reading data") - in_data = await self.reader.read(4096) - if not in_data: - log.info("connection closed (no data)") - self.ws.receive_data(None) - else: - log.debug("received {} bytes", len(in_data)) - self.ws.receive_data(in_data) - - for event in self.ws.events(): - log.debug("queued ws event {}", event) - await self._events.put(event) - - # since we closed, we don't have to continue reading - if not in_data: - return - async def recv(self, *, expect=Message, process_event: bool = True): # this loop is only done so we reply to pings while also being @@ -101,11 +78,20 @@ class AsyncWebsocket: # that's why it's left out. while True: + # if there's already an unprocessed event we can try getting + # it from wsproto first + event = None + for event in self.ws.events(): + break + + if event is None: + data = await self.reader.read(4096) + assert data # We expect the WebSocket to be closed correctly + self.ws.receive_data(data) + continue + # if we get a ping, reply with pong immediately # and fetch the next event - event = await self._events.get() - log.debug("processing {}", event) - if isinstance(event, Ping): await self.send(event.response()) continue @@ -154,7 +140,6 @@ class AsyncWebsocket: log.info("connecting to {!r} {}", host, port) self.reader, self.writer = await asyncio.open_connection(host, port) - self.reader_task = asyncio.create_task(self._reader_loop_task()) path = parsed.path or "/" target = f"{path}?{parsed.query}" if parsed.query else path