mirror of https://gitlab.com/litecord/litecord.git
Compare commits
18 Commits
3c815cf872
...
1d4f99f375
| Author | SHA1 | Date |
|---|---|---|
|
|
1d4f99f375 | |
|
|
2e346eb350 | |
|
|
b5b168c388 | |
|
|
5bd292422b | |
|
|
17e851e95f | |
|
|
df78bcaedf | |
|
|
6024eee19b | |
|
|
992b2cbf1f | |
|
|
de63efff82 | |
|
|
85c2bc3e18 | |
|
|
4270b934f9 | |
|
|
a9c3537b88 | |
|
|
1111fffd3a | |
|
|
768611cc4e | |
|
|
e3f894330d | |
|
|
2ad6b29175 | |
|
|
bba48f7d0f | |
|
|
b468883e2e |
|
|
@ -1,4 +1,4 @@
|
||||||
image: python:3.9-alpine
|
image: python:3.10-alpine
|
||||||
|
|
||||||
variables:
|
variables:
|
||||||
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
|
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,8 @@ from litecord.errors import MessageNotFound, Forbidden
|
||||||
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
|
||||||
|
|
||||||
from litecord.schemas import validate, MESSAGE_CREATE
|
from litecord.schemas import validate, MESSAGE_CREATE
|
||||||
from litecord.utils import pg_set_json, query_tuple_from_args, extract_limit
|
from litecord.utils import query_tuple_from_args, extract_limit
|
||||||
|
from litecord.json import pg_set_json
|
||||||
from litecord.permissions import get_permissions
|
from litecord.permissions import get_permissions
|
||||||
|
|
||||||
from litecord.embed.sanitizer import fill_embed
|
from litecord.embed.sanitizer import fill_embed
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,14 @@ async def api_gateway_bot():
|
||||||
user_id,
|
user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
max_concurrency = await app.db.fetchval(
|
||||||
|
"""select max_concurrency
|
||||||
|
from users
|
||||||
|
where id = $1
|
||||||
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
shards = max(int(guild_count / 1000), 1)
|
shards = max(int(guild_count / 1000), 1)
|
||||||
|
|
||||||
# get _ws.session ratelimit
|
# get _ws.session ratelimit
|
||||||
|
|
@ -78,7 +86,7 @@ async def api_gateway_bot():
|
||||||
"total": bucket.requests,
|
"total": bucket.requests,
|
||||||
"remaining": bucket._tokens,
|
"remaining": bucket._tokens,
|
||||||
"reset_after": int(reset_after_ts * 1000),
|
"reset_after": int(reset_after_ts * 1000),
|
||||||
"max_concurrency": 1,
|
"max_concurrency": max_concurrency,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from quart import Blueprint, current_app as app, render_template_string
|
from quart import Blueprint, current_app as app, render_template_string, send_file
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
bp = Blueprint("static", __name__)
|
bp = Blueprint("static", __name__)
|
||||||
|
|
@ -30,7 +30,10 @@ async def static_pages(path):
|
||||||
return "no", 404
|
return "no", 404
|
||||||
|
|
||||||
static_path = Path.cwd() / Path("static") / path
|
static_path = Path.cwd() / Path("static") / path
|
||||||
return await app.send_static_file(str(static_path))
|
if static_path.exists():
|
||||||
|
return await send_file(static_path)
|
||||||
|
else:
|
||||||
|
return "not found", 404
|
||||||
|
|
||||||
|
|
||||||
@bp.route("/")
|
@bp.route("/")
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,7 @@ from litecord.common.messages import (
|
||||||
from litecord.embed.sanitizer import fill_embed, fetch_mediaproxy_img
|
from litecord.embed.sanitizer import fill_embed, fetch_mediaproxy_img
|
||||||
from litecord.embed.messages import process_url_embed, is_media_url
|
from litecord.embed.messages import process_url_embed, is_media_url
|
||||||
from litecord.embed.schemas import EmbedURL
|
from litecord.embed.schemas import EmbedURL
|
||||||
from litecord.utils import pg_set_json
|
from litecord.json import pg_set_json
|
||||||
from litecord.enums import MessageType
|
from litecord.enums import MessageType
|
||||||
from litecord.images import STATIC_IMAGE_MIMES
|
from litecord.images import STATIC_IMAGE_MIMES
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
import json
|
import json
|
||||||
import earl
|
import earl
|
||||||
|
|
||||||
from litecord.utils import LitecordJSONEncoder
|
from litecord.json import LitecordJSONEncoder
|
||||||
|
|
||||||
|
|
||||||
def encode_json(payload) -> str:
|
def encode_json(payload) -> str:
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ import asyncio
|
||||||
import pprint
|
import pprint
|
||||||
import zlib
|
import zlib
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, Any, Iterable, Optional
|
from typing import List, Dict, Any, Iterable, Optional, Union
|
||||||
from random import randint
|
from random import randint
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
@ -66,6 +66,7 @@ from litecord.gateway.schemas import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.storage import int_
|
from litecord.storage import int_
|
||||||
|
from litecord.blueprints.gateway import get_gw
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
@ -353,8 +354,8 @@ class GatewayWebsocket:
|
||||||
"""Send a packet but just the OP code information is filled in."""
|
"""Send a packet but just the OP code information is filled in."""
|
||||||
await self.send({"op": op_code, "d": data, "t": None, "s": None})
|
await self.send({"op": op_code, "d": data, "t": None, "s": None})
|
||||||
|
|
||||||
def _check_ratelimit(self, key: str, ratelimit_key):
|
def _check_ratelimit(self, key: str, ratelimit_key: Any):
|
||||||
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
|
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}", exact=True)
|
||||||
bucket = ratelimit.get_bucket(ratelimit_key)
|
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||||
return bucket.update_rate_limit()
|
return bucket.update_rate_limit()
|
||||||
|
|
||||||
|
|
@ -396,6 +397,12 @@ class GatewayWebsocket:
|
||||||
such as READY and RESUMED, or events that are replies to
|
such as READY and RESUMED, or events that are replies to
|
||||||
messages in the websocket.
|
messages in the websocket.
|
||||||
"""
|
"""
|
||||||
|
if not self.state:
|
||||||
|
log.warning(
|
||||||
|
"can not dispatch {!r} as there is no state in ws {!r}", event, self
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
"op": OP.DISPATCH,
|
"op": OP.DISPATCH,
|
||||||
"t": event.upper(),
|
"t": event.upper(),
|
||||||
|
|
@ -523,6 +530,7 @@ class GatewayWebsocket:
|
||||||
"session_id": self.state.session_id,
|
"session_id": self.state.session_id,
|
||||||
"_trace": ["transbian"],
|
"_trace": ["transbian"],
|
||||||
"shard": [self.state.current_shard, self.state.shard_count],
|
"shard": [self.state.current_shard, self.state.shard_count],
|
||||||
|
"resume_gateway_url": get_gw(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# base_ready and user_ready are normalized as v6. from here onwards
|
# base_ready and user_ready are normalized as v6. from here onwards
|
||||||
|
|
@ -758,7 +766,8 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
await self.send_op(OP.HEARTBEAT_ACK, None)
|
await self.send_op(OP.HEARTBEAT_ACK, None)
|
||||||
|
|
||||||
async def _connect_ratelimit(self, user_id: int):
|
async def _connect_ratelimit(self, user_id: Union[int, str]):
|
||||||
|
log.debug("validating connect ratelimit against {!r}", user_id)
|
||||||
if self._check_ratelimit("connect", user_id):
|
if self._check_ratelimit("connect", user_id):
|
||||||
await self.invalidate_session(False)
|
await self.invalidate_session(False)
|
||||||
raise WebsocketClose(4009, "You are being ratelimited.")
|
raise WebsocketClose(4009, "You are being ratelimited.")
|
||||||
|
|
@ -791,7 +800,15 @@ class GatewayWebsocket:
|
||||||
except (Unauthorized, Forbidden):
|
except (Unauthorized, Forbidden):
|
||||||
raise WebsocketClose(4004, "Authentication failed")
|
raise WebsocketClose(4004, "Authentication failed")
|
||||||
|
|
||||||
await self._connect_ratelimit(user_id)
|
max_concurrency = await self.app.db.fetchval(
|
||||||
|
"""select max_concurrency
|
||||||
|
from users
|
||||||
|
where id = $1
|
||||||
|
""",
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._connect_ratelimit(f"{str(user_id)}%{shard[0]%max_concurrency}")
|
||||||
|
|
||||||
bot = await self.app.db.fetchval(
|
bot = await self.app.db.fetchval(
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""
|
||||||
|
|
||||||
|
Litecord
|
||||||
|
Copyright (C) 2018-2021 Luna Mendes and Litecord Contributors
|
||||||
|
|
||||||
|
This program is free software: you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation, version 3 of the License.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License
|
||||||
|
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
from decimal import Decimal
|
||||||
|
from uuid import UUID
|
||||||
|
from dataclasses import asdict, is_dataclass
|
||||||
|
|
||||||
|
import quart.json.provider
|
||||||
|
|
||||||
|
|
||||||
|
class LitecordJSONEncoder(json.JSONEncoder):
|
||||||
|
"""Custom JSON encoder for Litecord. Useful for json.dumps"""
|
||||||
|
|
||||||
|
def default(self, value: Any):
|
||||||
|
if isinstance(value, (Decimal, UUID)):
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
if is_dataclass(value):
|
||||||
|
return asdict(value)
|
||||||
|
|
||||||
|
if hasattr(value, "to_json"):
|
||||||
|
return value.to_json
|
||||||
|
|
||||||
|
return super().default(self, value)
|
||||||
|
|
||||||
|
|
||||||
|
class LitecordJSONProvider(quart.json.provider.DefaultJSONProvider):
|
||||||
|
"""Custom JSON provider for Quart."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
self.encoder = LitecordJSONEncoder(**kwargs)
|
||||||
|
|
||||||
|
def default(self, value: Any):
|
||||||
|
self.encoder.default(value)
|
||||||
|
|
||||||
|
|
||||||
|
async def pg_set_json(con):
|
||||||
|
"""Set JSON and JSONB codecs for an asyncpg connection."""
|
||||||
|
await con.set_type_codec(
|
||||||
|
"json",
|
||||||
|
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||||
|
decoder=json.loads,
|
||||||
|
schema="pg_catalog",
|
||||||
|
)
|
||||||
|
|
||||||
|
await con.set_type_codec(
|
||||||
|
"jsonb",
|
||||||
|
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
||||||
|
decoder=json.loads,
|
||||||
|
schema="pg_catalog",
|
||||||
|
)
|
||||||
|
|
@ -97,6 +97,8 @@ async def ratelimit_handler():
|
||||||
request.discord_api_version = 8
|
request.discord_api_version = 8
|
||||||
elif rule.rule.startswith("/api/v9"):
|
elif rule.rule.startswith("/api/v9"):
|
||||||
request.discord_api_version = 9
|
request.discord_api_version = 9
|
||||||
|
elif rule.rule.startswith("/api/v10"):
|
||||||
|
request.discord_api_version = 10
|
||||||
else:
|
else:
|
||||||
# default v6 lol
|
# default v6 lol
|
||||||
request.discord_api_version = 6
|
request.discord_api_version = 6
|
||||||
|
|
|
||||||
|
|
@ -75,6 +75,13 @@ class RatelimitManager:
|
||||||
|
|
||||||
self._ratelimiters[path] = rtl
|
self._ratelimiters[path] = rtl
|
||||||
|
|
||||||
def get_ratelimit(self, key: str) -> Ratelimit:
|
def get_ratelimit(self, key: str, exact=False) -> Ratelimit:
|
||||||
"""Get the :class:`Ratelimit` instance for a given path."""
|
"""Get the :class:`Ratelimit` instance for a given path."""
|
||||||
return self._ratelimiters.get(key, self.global_bucket)
|
bucket = self._ratelimiters.get(key)
|
||||||
|
if bucket:
|
||||||
|
return bucket
|
||||||
|
|
||||||
|
if not exact:
|
||||||
|
return self.global_bucket
|
||||||
|
else:
|
||||||
|
raise AssertionError(f"unknown ratelimit bucket '{key}'")
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ from litecord.blueprints.channel.reactions import (
|
||||||
from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
|
from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
|
||||||
|
|
||||||
from litecord.types import timestamp_
|
from litecord.types import timestamp_
|
||||||
from litecord.utils import pg_set_json
|
from litecord.json import pg_set_json
|
||||||
|
|
||||||
log = Logger(__name__)
|
log = Logger(__name__)
|
||||||
|
|
||||||
|
|
@ -432,7 +432,7 @@ class Storage:
|
||||||
|
|
||||||
return {**row, **drow}
|
return {**row, **drow}
|
||||||
elif chan_type == ChannelType.GUILD_VOICE:
|
elif chan_type == ChannelType.GUILD_VOICE:
|
||||||
vrow = await self.db.fetchrow(
|
voice_channel_data = await self.db.fetchrow(
|
||||||
"""
|
"""
|
||||||
SELECT bitrate, user_limit
|
SELECT bitrate, user_limit
|
||||||
FROM guild_voice_channels
|
FROM guild_voice_channels
|
||||||
|
|
@ -441,7 +441,19 @@ class Storage:
|
||||||
row["id"],
|
row["id"],
|
||||||
)
|
)
|
||||||
|
|
||||||
return {**row, **dict(vrow)}
|
guild_region = await self.db.fetchval(
|
||||||
|
"""
|
||||||
|
SELECT region
|
||||||
|
FROM guilds
|
||||||
|
WHERE guild.id = $1
|
||||||
|
""",
|
||||||
|
int(row["guild_id"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# see https://gitlab.com/litecord/litecord/-/issues/130
|
||||||
|
voice_channel_data["rtc_region"] = guild_region
|
||||||
|
|
||||||
|
return {**row, **dict(voice_channel_data)}
|
||||||
else:
|
else:
|
||||||
# this only exists to trick mypy. this codepath is unreachable
|
# this only exists to trick mypy. this codepath is unreachable
|
||||||
raise AssertionError("Unreachable code path.")
|
raise AssertionError("Unreachable code path.")
|
||||||
|
|
|
||||||
|
|
@ -18,14 +18,12 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import secrets
|
import secrets
|
||||||
import datetime
|
import datetime
|
||||||
import re
|
import re
|
||||||
from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
|
from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
|
||||||
|
|
||||||
from logbook import Logger
|
from logbook import Logger
|
||||||
from quart.json import JSONEncoder
|
|
||||||
from quart import current_app as app
|
from quart import current_app as app
|
||||||
|
|
||||||
from litecord.common.messages import message_view
|
from litecord.common.messages import message_view
|
||||||
|
|
@ -156,35 +154,6 @@ def mmh3(inp_str: str, seed: int = 0):
|
||||||
return _u(h1) >> 0
|
return _u(h1) >> 0
|
||||||
|
|
||||||
|
|
||||||
class LitecordJSONEncoder(JSONEncoder):
|
|
||||||
"""Custom JSON encoder for Litecord."""
|
|
||||||
|
|
||||||
def default(self, value: Any):
|
|
||||||
"""By default, this will try to get the to_json attribute of a given
|
|
||||||
value being JSON encoded."""
|
|
||||||
try:
|
|
||||||
return value.to_json
|
|
||||||
except AttributeError:
|
|
||||||
return super().default(value)
|
|
||||||
|
|
||||||
|
|
||||||
async def pg_set_json(con):
|
|
||||||
"""Set JSON and JSONB codecs for an asyncpg connection."""
|
|
||||||
await con.set_type_codec(
|
|
||||||
"json",
|
|
||||||
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
|
||||||
decoder=json.loads,
|
|
||||||
schema="pg_catalog",
|
|
||||||
)
|
|
||||||
|
|
||||||
await con.set_type_codec(
|
|
||||||
"jsonb",
|
|
||||||
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
|
|
||||||
decoder=json.loads,
|
|
||||||
schema="pg_catalog",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def yield_chunks(input_list: Sequence[Any], chunk_size: int):
|
def yield_chunks(input_list: Sequence[Any], chunk_size: int):
|
||||||
"""Yield successive n-sized chunks from l.
|
"""Yield successive n-sized chunks from l.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
alter table users
|
||||||
|
add column max_concurrency int not null default 1
|
||||||
|
check(bot = true or max_concurrency = 1);
|
||||||
|
|
@ -94,6 +94,42 @@ async def adduser(ctx, args):
|
||||||
print(f'\tdiscrim: {user["discriminator"]}')
|
print(f'\tdiscrim: {user["discriminator"]}')
|
||||||
|
|
||||||
|
|
||||||
|
async def set_max_concurrency(ctx, args):
|
||||||
|
"""Update the `max_concurrency` for a bot.
|
||||||
|
This can only be set for bot accounts!
|
||||||
|
"""
|
||||||
|
|
||||||
|
if int(args.max_concurrency) < 1:
|
||||||
|
return print("max_concurrency must be >0")
|
||||||
|
|
||||||
|
bot = await ctx.db.fetchval(
|
||||||
|
"""
|
||||||
|
select bot
|
||||||
|
from users
|
||||||
|
where id = $1
|
||||||
|
""",
|
||||||
|
int(args.user_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
if bot is None:
|
||||||
|
return print("user not found")
|
||||||
|
|
||||||
|
if not bot:
|
||||||
|
return print("user must be a bot")
|
||||||
|
|
||||||
|
await ctx.db.execute(
|
||||||
|
"""
|
||||||
|
update users
|
||||||
|
set max_concurrency = $1
|
||||||
|
where id = $2
|
||||||
|
""",
|
||||||
|
int(args.max_concurrency),
|
||||||
|
int(args.user_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"OK: set max_concurrency={args.max_concurrency} for {args.user_id}")
|
||||||
|
|
||||||
|
|
||||||
async def addbot(ctx, args):
|
async def addbot(ctx, args):
|
||||||
uid, _ = await create_user(args.username, args.email, args.password)
|
uid, _ = await create_user(args.username, args.email, args.password)
|
||||||
|
|
||||||
|
|
@ -216,6 +252,17 @@ def setup(subparser):
|
||||||
|
|
||||||
setup_test_parser.set_defaults(func=adduser)
|
setup_test_parser.set_defaults(func=adduser)
|
||||||
|
|
||||||
|
set_max_concurrency_parser = subparser.add_parser(
|
||||||
|
"set_max_concurrency",
|
||||||
|
help="set `max_concurrency` for a user",
|
||||||
|
description=set_max_concurrency.__doc__,
|
||||||
|
)
|
||||||
|
set_max_concurrency_parser.add_argument("user_id")
|
||||||
|
set_max_concurrency_parser.add_argument(
|
||||||
|
"max_concurrency", help="the `max_concurrency` value to set"
|
||||||
|
)
|
||||||
|
set_max_concurrency_parser.set_defaults(func=set_max_concurrency)
|
||||||
|
|
||||||
addbot_parser = subparser.add_parser("addbot", help="create a bot")
|
addbot_parser = subparser.add_parser("addbot", help="create a bot")
|
||||||
|
|
||||||
addbot_parser.add_argument("username", help="username of the bot")
|
addbot_parser.add_argument("username", help="username of the bot")
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -7,19 +7,19 @@ license = "GPLv3-only"
|
||||||
|
|
||||||
[tool.poetry.dependencies]
|
[tool.poetry.dependencies]
|
||||||
python = "^3.9"
|
python = "^3.9"
|
||||||
bcrypt = "^3.2.0"
|
bcrypt = "^3.2.2"
|
||||||
itsdangerous = "^1.1.0"
|
itsdangerous = "^2.1.2"
|
||||||
asyncpg = "^0.24.0"
|
asyncpg = "^0.26.0"
|
||||||
websockets = "^10.0"
|
websockets = "^10.3"
|
||||||
Earl-ETF = "^2.1.2"
|
Earl-ETF = "^2.1.2"
|
||||||
logbook = "^1.5.3"
|
logbook = "^1.5.3"
|
||||||
Cerberus = "^1.3.4"
|
Cerberus = "^1.3.4"
|
||||||
quart = {git = "https://gitlab.com/pgjones/quart", rev = "c1ac142c6c51709765045f830b242950099b2295"}
|
quart = "^0.18.0"
|
||||||
pillow = "^8.3.2"
|
pillow = "^9.2.0"
|
||||||
aiohttp = "^3.7.4"
|
aiohttp = "^3.8.1"
|
||||||
zstandard = "^0.15.2"
|
zstandard = "^0.18.0"
|
||||||
winter = {git = "https://gitlab.com/elixire/winter"}
|
winter = {git = "https://gitlab.com/elixire/winter"}
|
||||||
wsproto = "^1.0.0"
|
wsproto = "^1.1.0"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
8
run.py
8
run.py
|
|
@ -105,7 +105,7 @@ from litecord.pubsub.lazy_guild import LazyGuildManager
|
||||||
|
|
||||||
from litecord.gateway.gateway import websocket_handler
|
from litecord.gateway.gateway import websocket_handler
|
||||||
|
|
||||||
from litecord.utils import LitecordJSONEncoder
|
from litecord.json import LitecordJSONProvider
|
||||||
|
|
||||||
# == HACKY PATCH ==
|
# == HACKY PATCH ==
|
||||||
# this MUST be removed once Hypercorn gets py3.10 support.
|
# this MUST be removed once Hypercorn gets py3.10 support.
|
||||||
|
|
@ -135,12 +135,12 @@ def make_app():
|
||||||
logging.getLogger("websockets").setLevel(logbook.INFO)
|
logging.getLogger("websockets").setLevel(logbook.INFO)
|
||||||
|
|
||||||
# use our custom json encoder for custom data types
|
# use our custom json encoder for custom data types
|
||||||
app.json_encoder = LitecordJSONEncoder
|
app.json_provider_class = LitecordJSONProvider
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
PREFIXES = ("/api/v6", "/api/v7", "/api/v8", "/api/v9")
|
PREFIXES = ("/api/v6", "/api/v7", "/api/v8", "/api/v9", "/api/v10")
|
||||||
|
|
||||||
|
|
||||||
def set_blueprints(app_):
|
def set_blueprints(app_):
|
||||||
|
|
@ -415,7 +415,7 @@ async def app_after_serving():
|
||||||
# first close all clients, then close db
|
# first close all clients, then close db
|
||||||
tasks = app.state_manager.gen_close_tasks()
|
tasks = app.state_manager.gen_close_tasks()
|
||||||
if tasks:
|
if tasks:
|
||||||
await asyncio.wait(tasks, loop=app.loop)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
app.state_manager.close()
|
app.state_manager.close()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -78,3 +78,27 @@ async def test_guild_create(test_cli_user):
|
||||||
resp = await test_cli_user.delete(f"/api/v6/guilds/{guild_id}")
|
resp = await test_cli_user.delete(f"/api/v6/guilds/{guild_id}")
|
||||||
|
|
||||||
assert resp.status_code == 204
|
assert resp.status_code == 204
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_guild_nickname(test_cli_user):
|
||||||
|
guild = await test_cli_user.create_guild()
|
||||||
|
|
||||||
|
NEW_NICKNAME = "my awesome nickname"
|
||||||
|
# stage 1: create
|
||||||
|
resp = await test_cli_user.patch(
|
||||||
|
f"/api/v6/guilds/{guild.id}/members/@me/nick",
|
||||||
|
json={"nick": NEW_NICKNAME},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert (await resp.data).decode() == NEW_NICKNAME
|
||||||
|
|
||||||
|
# stage 2: test
|
||||||
|
resp = await test_cli_user.get(f"/api/v6/guilds/{guild.id}")
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
fetched_guild = await resp.json
|
||||||
|
|
||||||
|
assert fetched_guild["id"] == str(guild.id)
|
||||||
|
assert fetched_guild["members"][0]["nick"] == NEW_NICKNAME
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ from wsproto.events import (
|
||||||
|
|
||||||
from litecord.gateway.opcodes import OP
|
from litecord.gateway.opcodes import OP
|
||||||
from litecord.gateway.websocket import decode_etf
|
from litecord.gateway.websocket import decode_etf
|
||||||
|
from litecord.ratelimits.bucket import Ratelimit
|
||||||
|
|
||||||
# Z_SYNC_FLUSH suffix
|
# Z_SYNC_FLUSH suffix
|
||||||
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
|
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
|
||||||
|
|
@ -216,6 +217,7 @@ async def extract_and_verify_ready(conn, **kwargs):
|
||||||
assert isinstance(data["guilds"], list)
|
assert isinstance(data["guilds"], list)
|
||||||
assert isinstance(data["session_id"], str)
|
assert isinstance(data["session_id"], str)
|
||||||
assert isinstance(data["_trace"], list)
|
assert isinstance(data["_trace"], list)
|
||||||
|
assert isinstance(data["resume_gateway_url"], str)
|
||||||
|
|
||||||
if "shard" in data:
|
if "shard" in data:
|
||||||
assert isinstance(data["shard"], list)
|
assert isinstance(data["shard"], list)
|
||||||
|
|
@ -487,3 +489,46 @@ async def test_ready_bot_zlib_stream(test_cli_bot):
|
||||||
await extract_and_verify_ready(conn, zlib_stream=True)
|
await extract_and_verify_ready(conn, zlib_stream=True)
|
||||||
finally:
|
finally:
|
||||||
await _close(conn)
|
await _close(conn)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_concurrency(test_cli_bot):
|
||||||
|
session_ratelimiter = test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"]
|
||||||
|
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = Ratelimit(1, 5)
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
gateway_start_coroutines = [gw_start(test_cli_bot.cli) for _ in range(5)]
|
||||||
|
connections = await asyncio.gather(*gateway_start_coroutines)
|
||||||
|
|
||||||
|
# read all HELLOs we send by default
|
||||||
|
await asyncio.gather(*[_json(conn) for conn in connections])
|
||||||
|
|
||||||
|
# make everyone IDENTIFY
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
_json_send(
|
||||||
|
conn,
|
||||||
|
{"op": OP.IDENTIFY, "d": {"token": test_cli_bot.user["token"]}},
|
||||||
|
)
|
||||||
|
for conn in connections
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# only one of them gets a READY
|
||||||
|
success_count, error_count = 0, 0
|
||||||
|
for conn in connections:
|
||||||
|
try:
|
||||||
|
data = await _json(conn)
|
||||||
|
if data["op"] == OP.DISPATCH:
|
||||||
|
success_count += 1
|
||||||
|
else:
|
||||||
|
error_count += 1
|
||||||
|
finally:
|
||||||
|
await _close(conn)
|
||||||
|
assert success_count == 1
|
||||||
|
assert error_count == 4
|
||||||
|
|
||||||
|
finally:
|
||||||
|
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = session_ratelimiter
|
||||||
|
|
|
||||||
16
tox.ini
16
tox.ini
|
|
@ -1,22 +1,22 @@
|
||||||
[tox]
|
[tox]
|
||||||
envlist = py3.9
|
envlist = py3.10
|
||||||
isolated_build = true
|
isolated_build = true
|
||||||
|
|
||||||
[testenv]
|
[testenv]
|
||||||
ignore_errors = true
|
ignore_errors = true
|
||||||
deps =
|
deps =
|
||||||
pytest==6.2.5
|
pytest==7.1.2
|
||||||
pytest-asyncio==0.15.1
|
pytest-asyncio==0.19.0
|
||||||
pytest-cov==2.12.1
|
pytest-cov==3.0.0
|
||||||
flake8==3.9.2
|
flake8==5.0.4
|
||||||
black==21.6b0
|
black==22.6.0
|
||||||
mypy==0.910
|
mypy==0.971
|
||||||
pytest-instafail==0.4.2
|
pytest-instafail==0.4.2
|
||||||
commands =
|
commands =
|
||||||
python3 ./manage.py migrate
|
python3 ./manage.py migrate
|
||||||
black --check litecord run.py tests manage
|
black --check litecord run.py tests manage
|
||||||
flake8 litecord run.py tests manage
|
flake8 litecord run.py tests manage
|
||||||
pytest {posargs:tests}
|
pytest --asyncio-mode=auto {posargs:tests}
|
||||||
|
|
||||||
[flake8]
|
[flake8]
|
||||||
max-line-length = 88
|
max-line-length = 88
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue