Compare commits

...

9 Commits

Author SHA1 Message Date
Mai fc16231b9a Merge branch 'patch-3' into 'master'
Channel names no longer end with a dash

See merge request litecord/litecord!70
2022-02-06 06:32:10 +00:00
luna 386354fd65 Merge branch 'refactor/wsproto-tests' into 'master'
tests: add websockets->wsproto translation layer

Closes #139

See merge request litecord/litecord!85
2022-01-29 23:38:21 +00:00
luna 3b87a17477 tests: add websockets->wsproto translation layer 2022-01-29 23:38:21 +00:00
Luna 6ac705f838 update dependencies 2022-01-27 21:56:16 -03:00
luna 6a617cf376 Merge branch 'addbot-manage' into 'master'
Add 'addbot' manage.py command

See merge request litecord/litecord!84
2022-01-27 14:40:26 +00:00
Bluenix 2fe50c4ac3 Add 'addbot' manage.py command 2022-01-27 14:40:26 +00:00
Mai 549d5992fd Finally got things to work 2020-09-12 05:07:51 +00:00
Mai fbc15219f7 I am terrible at python jfc 2020-09-10 20:28:39 +00:00
Mai e9edd8ba37 Channel names are not supposed to end with a dash, pretty ghetto way of doing it but I'm tired af 2020-09-10 20:20:27 +00:00
7 changed files with 652 additions and 310 deletions

View File

@ -16,4 +16,3 @@ You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
""" """

View File

@ -191,6 +191,12 @@ def validate(
if reqjson is None: if reqjson is None:
raise BadRequest("No JSON provided") raise BadRequest("No JSON provided")
try:
if schema["type"]["type"] == "channel_type" and reqjson["name"][-1] == "-":
reqjson["name"] = reqjson["name"][:-1]
except:
pass
try: try:
valid = validator.validate(reqjson) valid = validator.validate(reqjson)
except Exception: except Exception:

View File

@ -94,6 +94,23 @@ async def adduser(ctx, args):
print(f'\tdiscrim: {user["discriminator"]}') print(f'\tdiscrim: {user["discriminator"]}')
async def addbot(ctx, args):
uid, _ = await create_user(args.username, args.email, args.password)
await ctx.db.execute(
"""
UPDATE users
SET bot=True
WHERE id = $1
""",
uid,
)
args.user_id = uid
return await generate_bot_token(ctx, args)
async def set_flag(ctx, args): async def set_flag(ctx, args):
"""Setting a 'staff' flag gives the user access to the Admin API. """Setting a 'staff' flag gives the user access to the Admin API.
Beware of that. Beware of that.
@ -138,7 +155,8 @@ async def generate_bot_token(ctx, args):
) )
if not password_hash: if not password_hash:
return print("cannot find a bot with specified id") print("cannot find a bot with specified id")
return 1
print(make_token(args.user_id, password_hash)) print(make_token(args.user_id, password_hash))
@ -198,6 +216,14 @@ def setup(subparser):
setup_test_parser.set_defaults(func=adduser) setup_test_parser.set_defaults(func=adduser)
addbot_parser = subparser.add_parser("addbot", help="create a bot")
addbot_parser.add_argument("username", help="username of the bot")
addbot_parser.add_argument("email", help="email of the bot")
addbot_parser.add_argument("password", help="password of the bot")
addbot_parser.set_defaults(func=addbot)
setflag_parser = subparser.add_parser( setflag_parser = subparser.add_parser(
"setflag", help="set a flag for a user", description=set_flag.__doc__ "setflag", help="set a flag for a user", description=set_flag.__doc__
) )

View File

@ -93,7 +93,7 @@ def main(config):
async def _ctx_wrapper(fake_app, args): async def _ctx_wrapper(fake_app, args):
app = fake_app.make_app() app = fake_app.make_app()
async with app.app_context(): async with app.app_context():
await args.func(fake_app, args) return await args.func(fake_app, args)
try: try:
if len(argv) < 2: if len(argv) < 2:
@ -107,8 +107,9 @@ def main(config):
init_app_managers(app, init_voice=False) init_app_managers(app, init_voice=False)
args = parser.parse_args() args = parser.parse_args()
loop.run_until_complete(_ctx_wrapper(app, args)) return loop.run_until_complete(_ctx_wrapper(app, args))
except Exception: except Exception:
log.exception("error while running command") log.exception("error while running command")
return 1
finally: finally:
loop.run_until_complete(app.db.close()) loop.run_until_complete(app.db.close())

787
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -19,6 +19,7 @@ pillow = "^8.3.2"
aiohttp = "^3.7.4" aiohttp = "^3.7.4"
zstandard = "^0.15.2" zstandard = "^0.15.2"
winter = {git = "https://gitlab.com/elixire/winter"} winter = {git = "https://gitlab.com/elixire/winter"}
wsproto = "^1.0.0"

View File

@ -19,10 +19,23 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import json import json
import zlib import zlib
import asyncio
import urllib.parse
import collections
from typing import Optional from typing import Optional
import pytest import pytest
import websockets import websockets
from logbook import Logger
from wsproto import WSConnection, ConnectionType
from wsproto.connection import ConnectionState
from wsproto.events import (
Request,
Message,
AcceptConnection,
CloseConnection,
Ping,
)
from litecord.gateway.opcodes import OP from litecord.gateway.opcodes import OP
from litecord.gateway.websocket import decode_etf from litecord.gateway.websocket import decode_etf
@ -31,6 +44,109 @@ from litecord.gateway.websocket import decode_etf
ZLIB_SUFFIX = b"\x00\x00\xff\xff" ZLIB_SUFFIX = b"\x00\x00\xff\xff"
log = Logger("test_websocket")
RcvdWrapper = collections.namedtuple("RcvdWrapper", "code reason")
class AsyncWebsocket:
"""websockets-compatible websocket object"""
def __init__(self, url):
self.url = url
self.ws = WSConnection(ConnectionType.CLIENT)
self.reader, self.writer = None, None
async def send(self, data):
assert self.writer is not None
# wrap all strings in Message
if isinstance(data, str):
data = Message(data=data)
log.debug("sending {} event", type(data))
self.writer.write(self.ws.send(data))
await self.writer.drain()
async def recv(self, *, expect=Message, process_event: bool = True):
# this loop is only done so we reply to pings while also being
# able to receive any other event in the middle.
#
# CloseConnection does not lead us to reading other events, so
# that's why it's left out.
while True:
# if there's already an unprocessed event we can try getting
# it from wsproto first
event = None
for event in self.ws.events():
break
if event is None:
data = await self.reader.read(4096)
assert data # We expect the WebSocket to be closed correctly
self.ws.receive_data(data)
continue
# if we get a ping, reply with pong immediately
# and fetch the next event
if isinstance(event, Ping):
await self.send(event.response())
continue
break
if isinstance(event, CloseConnection):
assert self.ws.state is ConnectionState.REMOTE_CLOSING
await self.send(event.response())
if process_event:
raise websockets.ConnectionClosed(
RcvdWrapper(event.code, event.reason), None
)
if expect is not None and not isinstance(event, expect):
raise AssertionError(
f"Expected {expect!r} websocket event, got {type(event)!r}"
)
# this keeps compatibility with code written for aaugustin/websockets
if expect is Message and process_event:
return event.data
return event
async def close(self, close_code: int, close_reason: str):
log.info("closing connection")
event = CloseConnection(code=close_code, reason=close_reason)
await self.send(event)
self.writer.close()
await self.writer.wait_closed()
self.ws.receive_data(None)
async def connect(self):
parsed = urllib.parse.urlparse(self.url)
if parsed.scheme == "wss":
port = 443
elif parsed.scheme == "ws":
port = 80
else:
raise AssertionError("Invalid url scheme")
host, *rest = parsed.netloc.split(":")
if rest:
port = rest[0]
log.info("connecting to {!r} {}", host, port)
self.reader, self.writer = await asyncio.open_connection(host, port)
path = parsed.path or "/"
target = f"{path}?{parsed.query}" if parsed.query else path
await self.send(Request(host=parsed.netloc, target=target))
await self.recv(expect=AcceptConnection)
async def _recv(conn, *, zlib_stream: bool): async def _recv(conn, *, zlib_stream: bool):
if zlib_stream: if zlib_stream:
try: try:
@ -43,11 +159,16 @@ async def _recv(conn, *, zlib_stream: bool):
zlib_buffer = bytearray() zlib_buffer = bytearray()
while True: while True:
# keep receiving frames until we find the zlib prefix inside # keep receiving frames until we find the zlib prefix inside
msg = await conn.recv() # we set process_event to false so that we get the entire event
zlib_buffer.extend(msg) # instead of only data
if len(msg) < 4 or msg[-4:] != ZLIB_SUFFIX: event = await conn.recv(process_event=False)
zlib_buffer.extend(event.data)
if not event.message_finished:
continue continue
if len(zlib_buffer) < 4 or zlib_buffer[-4:] != ZLIB_SUFFIX:
raise RuntimeError("Finished compressed message without ZLIB suffix")
# NOTE: the message is utf-8 encoded. # NOTE: the message is utf-8 encoded.
msg = conn._zlib_context.decompress(zlib_buffer) msg = conn._zlib_context.decompress(zlib_buffer)
return msg return msg
@ -119,7 +240,10 @@ async def gw_start(
gw_url = f"{gw_url}?v={version}&encoding=json" gw_url = f"{gw_url}?v={version}&encoding=json"
compress = f"&compress={compress}" if compress else "" compress = f"&compress={compress}" if compress else ""
return await websockets.connect(f"{gw_url}{compress}")
ws = AsyncWebsocket(f"{gw_url}{compress}")
await ws.connect()
return ws
@pytest.mark.asyncio @pytest.mark.asyncio
@ -174,8 +298,6 @@ async def test_broken_identify(test_cli_user):
raise AssertionError("Received a JSON message but expected close") raise AssertionError("Received a JSON message but expected close")
except websockets.ConnectionClosed as exc: except websockets.ConnectionClosed as exc:
assert exc.code == 4002 assert exc.code == 4002
finally:
await _close(conn)
@pytest.mark.asyncio @pytest.mark.asyncio