mirror of https://gitlab.com/litecord/litecord.git
Compare commits
1 Commits
fc16231b9a
...
8fa78b47f5
| Author | SHA1 | Date |
|---|---|---|
|
|
8fa78b47f5 |
|
|
@ -16,3 +16,4 @@ You should have received a copy of the GNU General Public License
|
||||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -94,23 +94,6 @@ async def adduser(ctx, args):
|
||||||
print(f'\tdiscrim: {user["discriminator"]}')
|
print(f'\tdiscrim: {user["discriminator"]}')
|
||||||
|
|
||||||
|
|
||||||
async def addbot(ctx, args):
|
|
||||||
uid, _ = await create_user(args.username, args.email, args.password)
|
|
||||||
|
|
||||||
await ctx.db.execute(
|
|
||||||
"""
|
|
||||||
UPDATE users
|
|
||||||
SET bot=True
|
|
||||||
WHERE id = $1
|
|
||||||
""",
|
|
||||||
uid,
|
|
||||||
)
|
|
||||||
|
|
||||||
args.user_id = uid
|
|
||||||
|
|
||||||
return await generate_bot_token(ctx, args)
|
|
||||||
|
|
||||||
|
|
||||||
async def set_flag(ctx, args):
|
async def set_flag(ctx, args):
|
||||||
"""Setting a 'staff' flag gives the user access to the Admin API.
|
"""Setting a 'staff' flag gives the user access to the Admin API.
|
||||||
Beware of that.
|
Beware of that.
|
||||||
|
|
@ -155,8 +138,7 @@ async def generate_bot_token(ctx, args):
|
||||||
)
|
)
|
||||||
|
|
||||||
if not password_hash:
|
if not password_hash:
|
||||||
print("cannot find a bot with specified id")
|
return print("cannot find a bot with specified id")
|
||||||
return 1
|
|
||||||
|
|
||||||
print(make_token(args.user_id, password_hash))
|
print(make_token(args.user_id, password_hash))
|
||||||
|
|
||||||
|
|
@ -216,14 +198,6 @@ def setup(subparser):
|
||||||
|
|
||||||
setup_test_parser.set_defaults(func=adduser)
|
setup_test_parser.set_defaults(func=adduser)
|
||||||
|
|
||||||
addbot_parser = subparser.add_parser("addbot", help="create a bot")
|
|
||||||
|
|
||||||
addbot_parser.add_argument("username", help="username of the bot")
|
|
||||||
addbot_parser.add_argument("email", help="email of the bot")
|
|
||||||
addbot_parser.add_argument("password", help="password of the bot")
|
|
||||||
|
|
||||||
addbot_parser.set_defaults(func=addbot)
|
|
||||||
|
|
||||||
setflag_parser = subparser.add_parser(
|
setflag_parser = subparser.add_parser(
|
||||||
"setflag", help="set a flag for a user", description=set_flag.__doc__
|
"setflag", help="set a flag for a user", description=set_flag.__doc__
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,7 @@ def main(config):
|
||||||
async def _ctx_wrapper(fake_app, args):
|
async def _ctx_wrapper(fake_app, args):
|
||||||
app = fake_app.make_app()
|
app = fake_app.make_app()
|
||||||
async with app.app_context():
|
async with app.app_context():
|
||||||
return await args.func(fake_app, args)
|
await args.func(fake_app, args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if len(argv) < 2:
|
if len(argv) < 2:
|
||||||
|
|
@ -107,9 +107,8 @@ def main(config):
|
||||||
init_app_managers(app, init_voice=False)
|
init_app_managers(app, init_voice=False)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return loop.run_until_complete(_ctx_wrapper(app, args))
|
loop.run_until_complete(_ctx_wrapper(app, args))
|
||||||
except Exception:
|
except Exception:
|
||||||
log.exception("error while running command")
|
log.exception("error while running command")
|
||||||
return 1
|
|
||||||
finally:
|
finally:
|
||||||
loop.run_until_complete(app.db.close())
|
loop.run_until_complete(app.db.close())
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -19,7 +19,6 @@ pillow = "^8.3.2"
|
||||||
aiohttp = "^3.7.4"
|
aiohttp = "^3.7.4"
|
||||||
zstandard = "^0.15.2"
|
zstandard = "^0.15.2"
|
||||||
winter = {git = "https://gitlab.com/elixire/winter"}
|
winter = {git = "https://gitlab.com/elixire/winter"}
|
||||||
wsproto = "^1.0.0"
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -19,23 +19,10 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import zlib
|
import zlib
|
||||||
import asyncio
|
|
||||||
import urllib.parse
|
|
||||||
import collections
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import websockets
|
import websockets
|
||||||
from logbook import Logger
|
|
||||||
from wsproto import WSConnection, ConnectionType
|
|
||||||
from wsproto.connection import ConnectionState
|
|
||||||
from wsproto.events import (
|
|
||||||
Request,
|
|
||||||
Message,
|
|
||||||
AcceptConnection,
|
|
||||||
CloseConnection,
|
|
||||||
Ping,
|
|
||||||
)
|
|
||||||
|
|
||||||
from litecord.gateway.opcodes import OP
|
from litecord.gateway.opcodes import OP
|
||||||
from litecord.gateway.websocket import decode_etf
|
from litecord.gateway.websocket import decode_etf
|
||||||
|
|
@ -44,109 +31,6 @@ from litecord.gateway.websocket import decode_etf
|
||||||
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
|
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
|
||||||
|
|
||||||
|
|
||||||
log = Logger("test_websocket")
|
|
||||||
|
|
||||||
RcvdWrapper = collections.namedtuple("RcvdWrapper", "code reason")
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncWebsocket:
|
|
||||||
"""websockets-compatible websocket object"""
|
|
||||||
|
|
||||||
def __init__(self, url):
|
|
||||||
self.url = url
|
|
||||||
self.ws = WSConnection(ConnectionType.CLIENT)
|
|
||||||
self.reader, self.writer = None, None
|
|
||||||
|
|
||||||
async def send(self, data):
|
|
||||||
assert self.writer is not None
|
|
||||||
|
|
||||||
# wrap all strings in Message
|
|
||||||
if isinstance(data, str):
|
|
||||||
data = Message(data=data)
|
|
||||||
|
|
||||||
log.debug("sending {} event", type(data))
|
|
||||||
|
|
||||||
self.writer.write(self.ws.send(data))
|
|
||||||
await self.writer.drain()
|
|
||||||
|
|
||||||
async def recv(self, *, expect=Message, process_event: bool = True):
|
|
||||||
|
|
||||||
# this loop is only done so we reply to pings while also being
|
|
||||||
# able to receive any other event in the middle.
|
|
||||||
#
|
|
||||||
# CloseConnection does not lead us to reading other events, so
|
|
||||||
# that's why it's left out.
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# if there's already an unprocessed event we can try getting
|
|
||||||
# it from wsproto first
|
|
||||||
event = None
|
|
||||||
for event in self.ws.events():
|
|
||||||
break
|
|
||||||
|
|
||||||
if event is None:
|
|
||||||
data = await self.reader.read(4096)
|
|
||||||
assert data # We expect the WebSocket to be closed correctly
|
|
||||||
self.ws.receive_data(data)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# if we get a ping, reply with pong immediately
|
|
||||||
# and fetch the next event
|
|
||||||
if isinstance(event, Ping):
|
|
||||||
await self.send(event.response())
|
|
||||||
continue
|
|
||||||
|
|
||||||
break
|
|
||||||
|
|
||||||
if isinstance(event, CloseConnection):
|
|
||||||
assert self.ws.state is ConnectionState.REMOTE_CLOSING
|
|
||||||
await self.send(event.response())
|
|
||||||
if process_event:
|
|
||||||
raise websockets.ConnectionClosed(
|
|
||||||
RcvdWrapper(event.code, event.reason), None
|
|
||||||
)
|
|
||||||
|
|
||||||
if expect is not None and not isinstance(event, expect):
|
|
||||||
raise AssertionError(
|
|
||||||
f"Expected {expect!r} websocket event, got {type(event)!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# this keeps compatibility with code written for aaugustin/websockets
|
|
||||||
if expect is Message and process_event:
|
|
||||||
return event.data
|
|
||||||
|
|
||||||
return event
|
|
||||||
|
|
||||||
async def close(self, close_code: int, close_reason: str):
|
|
||||||
log.info("closing connection")
|
|
||||||
event = CloseConnection(code=close_code, reason=close_reason)
|
|
||||||
await self.send(event)
|
|
||||||
self.writer.close()
|
|
||||||
await self.writer.wait_closed()
|
|
||||||
self.ws.receive_data(None)
|
|
||||||
|
|
||||||
async def connect(self):
|
|
||||||
parsed = urllib.parse.urlparse(self.url)
|
|
||||||
if parsed.scheme == "wss":
|
|
||||||
port = 443
|
|
||||||
elif parsed.scheme == "ws":
|
|
||||||
port = 80
|
|
||||||
else:
|
|
||||||
raise AssertionError("Invalid url scheme")
|
|
||||||
|
|
||||||
host, *rest = parsed.netloc.split(":")
|
|
||||||
if rest:
|
|
||||||
port = rest[0]
|
|
||||||
|
|
||||||
log.info("connecting to {!r} {}", host, port)
|
|
||||||
self.reader, self.writer = await asyncio.open_connection(host, port)
|
|
||||||
|
|
||||||
path = parsed.path or "/"
|
|
||||||
target = f"{path}?{parsed.query}" if parsed.query else path
|
|
||||||
await self.send(Request(host=parsed.netloc, target=target))
|
|
||||||
await self.recv(expect=AcceptConnection)
|
|
||||||
|
|
||||||
|
|
||||||
async def _recv(conn, *, zlib_stream: bool):
|
async def _recv(conn, *, zlib_stream: bool):
|
||||||
if zlib_stream:
|
if zlib_stream:
|
||||||
try:
|
try:
|
||||||
|
|
@ -159,16 +43,11 @@ async def _recv(conn, *, zlib_stream: bool):
|
||||||
zlib_buffer = bytearray()
|
zlib_buffer = bytearray()
|
||||||
while True:
|
while True:
|
||||||
# keep receiving frames until we find the zlib prefix inside
|
# keep receiving frames until we find the zlib prefix inside
|
||||||
# we set process_event to false so that we get the entire event
|
msg = await conn.recv()
|
||||||
# instead of only data
|
zlib_buffer.extend(msg)
|
||||||
event = await conn.recv(process_event=False)
|
if len(msg) < 4 or msg[-4:] != ZLIB_SUFFIX:
|
||||||
zlib_buffer.extend(event.data)
|
|
||||||
if not event.message_finished:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if len(zlib_buffer) < 4 or zlib_buffer[-4:] != ZLIB_SUFFIX:
|
|
||||||
raise RuntimeError("Finished compressed message without ZLIB suffix")
|
|
||||||
|
|
||||||
# NOTE: the message is utf-8 encoded.
|
# NOTE: the message is utf-8 encoded.
|
||||||
msg = conn._zlib_context.decompress(zlib_buffer)
|
msg = conn._zlib_context.decompress(zlib_buffer)
|
||||||
return msg
|
return msg
|
||||||
|
|
@ -240,10 +119,7 @@ async def gw_start(
|
||||||
gw_url = f"{gw_url}?v={version}&encoding=json"
|
gw_url = f"{gw_url}?v={version}&encoding=json"
|
||||||
|
|
||||||
compress = f"&compress={compress}" if compress else ""
|
compress = f"&compress={compress}" if compress else ""
|
||||||
|
return await websockets.connect(f"{gw_url}{compress}")
|
||||||
ws = AsyncWebsocket(f"{gw_url}{compress}")
|
|
||||||
await ws.connect()
|
|
||||||
return ws
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -298,6 +174,8 @@ async def test_broken_identify(test_cli_user):
|
||||||
raise AssertionError("Received a JSON message but expected close")
|
raise AssertionError("Received a JSON message but expected close")
|
||||||
except websockets.ConnectionClosed as exc:
|
except websockets.ConnectionClosed as exc:
|
||||||
assert exc.code == 4002
|
assert exc.code == 4002
|
||||||
|
finally:
|
||||||
|
await _close(conn)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue