mirror of https://gitlab.com/litecord/litecord.git
Compare commits
4 Commits
28299607cd
...
2180fbca02
| Author | SHA1 | Date |
|---|---|---|
|
|
2180fbca02 | |
|
|
4a70d9580d | |
|
|
049523b03f | |
|
|
38b560205f |
|
|
@ -28,6 +28,7 @@ import pytest
|
|||
import websockets
|
||||
from logbook import Logger
|
||||
from wsproto import WSConnection, ConnectionType
|
||||
from wsproto.connection import ConnectionState
|
||||
from wsproto.events import (
|
||||
Request,
|
||||
Message,
|
||||
|
|
@ -55,6 +56,8 @@ class AsyncWebsocket:
|
|||
self.url = url
|
||||
self.ws = WSConnection(ConnectionType.CLIENT)
|
||||
self.reader, self.writer = None, None
|
||||
self.reader_task = None
|
||||
self._events = asyncio.Queue()
|
||||
|
||||
async def send(self, data):
|
||||
assert self.writer is not None
|
||||
|
|
@ -68,23 +71,49 @@ class AsyncWebsocket:
|
|||
self.writer.write(self.ws.send(data))
|
||||
await self.writer.drain()
|
||||
|
||||
async def recv(self, *, expect=Message, process_event: bool = True):
|
||||
in_data = await self.reader.read(4096)
|
||||
if not in_data:
|
||||
log.info("connection closed (no data)")
|
||||
self.ws.receive_data(None)
|
||||
else:
|
||||
log.debug("received {} bytes", len(in_data))
|
||||
self.ws.receive_data(in_data)
|
||||
async def _reader_loop_task(self):
|
||||
# continuously read messages from the socket
|
||||
# and fill up the _events queue with them
|
||||
while True:
|
||||
log.info("reading data")
|
||||
in_data = await self.reader.read(4096)
|
||||
if not in_data:
|
||||
log.info("connection closed (no data)")
|
||||
self.ws.receive_data(None)
|
||||
else:
|
||||
log.debug("received {} bytes", len(in_data))
|
||||
self.ws.receive_data(in_data)
|
||||
|
||||
# if we get a ping, reply with pong immediately
|
||||
# and fetch the next event
|
||||
event = next(self.ws.events())
|
||||
if isinstance(event, Ping):
|
||||
await self.send(event.response())
|
||||
event = next(self.ws.events())
|
||||
for event in self.ws.events():
|
||||
log.debug("queued ws event {}", event)
|
||||
await self._events.put(event)
|
||||
|
||||
# since we closed, we don't have to continue reading
|
||||
if not in_data:
|
||||
return
|
||||
|
||||
async def recv(self, *, expect=Message, process_event: bool = True):
|
||||
|
||||
# this loop is only done so we reply to pings while also being
|
||||
# able to receive any other event in the middle.
|
||||
#
|
||||
# CloseConnection does not lead us to reading other events, so
|
||||
# that's why it's left out.
|
||||
|
||||
while True:
|
||||
# if we get a ping, reply with pong immediately
|
||||
# and fetch the next event
|
||||
event = await self._events.get()
|
||||
log.debug("processing {}", event)
|
||||
|
||||
if isinstance(event, Ping):
|
||||
await self.send(event.response())
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
if isinstance(event, CloseConnection):
|
||||
assert self.ws.state is ConnectionState.REMOTE_CLOSING
|
||||
await self.send(event.response())
|
||||
if process_event:
|
||||
raise websockets.ConnectionClosed(
|
||||
|
|
@ -125,6 +154,8 @@ class AsyncWebsocket:
|
|||
|
||||
log.info("connecting to {!r} {}", host, port)
|
||||
self.reader, self.writer = await asyncio.open_connection(host, port)
|
||||
self.reader_task = asyncio.create_task(self._reader_loop_task())
|
||||
|
||||
path = parsed.path or "/"
|
||||
target = f"{path}?{parsed.query}" if parsed.query else path
|
||||
await self.send(Request(host=parsed.netloc, target=target))
|
||||
|
|
@ -282,8 +313,6 @@ async def test_broken_identify(test_cli_user):
|
|||
raise AssertionError("Received a JSON message but expected close")
|
||||
except websockets.ConnectionClosed as exc:
|
||||
assert exc.code == 4002
|
||||
finally:
|
||||
await _close(conn)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Reference in New Issue