mirror of https://gitlab.com/litecord/litecord.git
tests: use an internal queue for wsproto events
This commit is contained in:
parent
28299607cd
commit
38b560205f
|
|
@ -55,6 +55,9 @@ class AsyncWebsocket:
|
||||||
self.url = url
|
self.url = url
|
||||||
self.ws = WSConnection(ConnectionType.CLIENT)
|
self.ws = WSConnection(ConnectionType.CLIENT)
|
||||||
self.reader, self.writer = None, None
|
self.reader, self.writer = None, None
|
||||||
|
self.reader_task = None
|
||||||
|
self._waiting_for_message_event = None
|
||||||
|
self._events = asyncio.Queue()
|
||||||
|
|
||||||
async def send(self, data):
|
async def send(self, data):
|
||||||
assert self.writer is not None
|
assert self.writer is not None
|
||||||
|
|
@ -68,21 +71,65 @@ class AsyncWebsocket:
|
||||||
self.writer.write(self.ws.send(data))
|
self.writer.write(self.ws.send(data))
|
||||||
await self.writer.drain()
|
await self.writer.drain()
|
||||||
|
|
||||||
async def recv(self, *, expect=Message, process_event: bool = True):
|
async def _reader_loop_task(self):
|
||||||
in_data = await self.reader.read(4096)
|
|
||||||
if not in_data:
|
|
||||||
log.info("connection closed (no data)")
|
|
||||||
self.ws.receive_data(None)
|
|
||||||
else:
|
|
||||||
log.debug("received {} bytes", len(in_data))
|
|
||||||
self.ws.receive_data(in_data)
|
|
||||||
|
|
||||||
# if we get a ping, reply with pong immediately
|
# continuously read messages from the socket
|
||||||
# and fetch the next event
|
# and fill up the _events queue with them
|
||||||
event = next(self.ws.events())
|
#
|
||||||
if isinstance(event, Ping):
|
# if a recv() coroutine has been waiting for an event
|
||||||
await self.send(event.response())
|
# (via _waiting_for_message_event), then set that event so that
|
||||||
event = next(self.ws.events())
|
# we immediately process it
|
||||||
|
while True:
|
||||||
|
log.info("reading data")
|
||||||
|
in_data = await self.reader.read(4096)
|
||||||
|
if not in_data:
|
||||||
|
log.info("connection closed (no data)")
|
||||||
|
self.ws.receive_data(None)
|
||||||
|
else:
|
||||||
|
log.debug("received {} bytes", len(in_data))
|
||||||
|
self.ws.receive_data(in_data)
|
||||||
|
|
||||||
|
for event in self.ws.events():
|
||||||
|
log.debug("queued ws event {}", event)
|
||||||
|
await self._events.put(event)
|
||||||
|
|
||||||
|
if not self._events.empty() and self._waiting_for_message_event:
|
||||||
|
self._waiting_for_message_event.set()
|
||||||
|
|
||||||
|
# since we closed, we don't have to continue reading
|
||||||
|
if not in_data:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def recv(self, *, expect=Message, process_event: bool = True):
|
||||||
|
|
||||||
|
# if queue is empty, wait until it's filled up
|
||||||
|
if self._events.empty():
|
||||||
|
self._waiting_for_message_event = asyncio.Event()
|
||||||
|
try:
|
||||||
|
await asyncio.wait(
|
||||||
|
[self._waiting_for_message_event.wait(), self.reader_task],
|
||||||
|
return_when=asyncio.FIRST_COMPLETED,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
self._waiting_for_message_event = None
|
||||||
|
|
||||||
|
# this loop is only done so we reply to pings while also being
|
||||||
|
# able to receive any other event in the middle.
|
||||||
|
#
|
||||||
|
# CloseConnection does not lead us to reading other events, so
|
||||||
|
# that's why it's left out.
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# if we get a ping, reply with pong immediately
|
||||||
|
# and fetch the next event
|
||||||
|
event = self._events.get_nowait()
|
||||||
|
log.debug("processing {}", event)
|
||||||
|
|
||||||
|
if isinstance(event, Ping):
|
||||||
|
await self.send(event.response())
|
||||||
|
continue
|
||||||
|
|
||||||
|
break
|
||||||
|
|
||||||
if isinstance(event, CloseConnection):
|
if isinstance(event, CloseConnection):
|
||||||
await self.send(event.response())
|
await self.send(event.response())
|
||||||
|
|
@ -125,6 +172,8 @@ class AsyncWebsocket:
|
||||||
|
|
||||||
log.info("connecting to {!r} {}", host, port)
|
log.info("connecting to {!r} {}", host, port)
|
||||||
self.reader, self.writer = await asyncio.open_connection(host, port)
|
self.reader, self.writer = await asyncio.open_connection(host, port)
|
||||||
|
self.reader_task = asyncio.create_task(self._reader_loop_task())
|
||||||
|
|
||||||
path = parsed.path or "/"
|
path = parsed.path or "/"
|
||||||
target = f"{path}?{parsed.query}" if parsed.query else path
|
target = f"{path}?{parsed.query}" if parsed.query else path
|
||||||
await self.send(Request(host=parsed.netloc, target=target))
|
await self.send(Request(host=parsed.netloc, target=target))
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue