Compare commits

...

18 Commits

Author SHA1 Message Date
Luna 1d4f99f375 add test for guild nickname setting 2022-08-13 23:14:32 -03:00
Luna 2e346eb350 fix typo 2022-08-13 22:50:30 -03:00
Luna b5b168c388 storage: copy guild.region into channel.rtc_region
close #130
2022-08-13 22:46:31 -03:00
Luna 5bd292422b gateway: do not dispatch raw events when state isn't set
this is possible if a connection quickly disconnects while connecting

close #136
2022-08-13 22:43:09 -03:00
Luna 17e851e95f gateway: add resume_gateway_url to ready
close #143
2022-08-13 22:39:41 -03:00
Luna df78bcaedf add v10 mapping 2022-08-13 19:37:35 -03:00
Luna 6024eee19b add test for max_concurrency
see https://gitlab.com/litecord/litecord/-/merge_requests/82#note_728950904
2022-08-13 19:35:51 -03:00
Luna 992b2cbf1f lint pass 2022-08-13 18:54:42 -03:00
luna de63efff82 Merge branch 'feat/max_concurrency' into 'master'
gateway: add max_concurrency support

See merge request litecord/litecord!82
2022-08-13 21:54:05 +00:00
luna 85c2bc3e18 Merge branch 'master' into 'feat/max_concurrency'
# Conflicts:
#   manage/cmd/users.py
2022-08-13 21:49:22 +00:00
Luna 4270b934f9 add v10 to default version prefix set 2022-08-13 17:04:21 -03:00
Luna a9c3537b88 fix static file send 2022-08-13 17:04:13 -03:00
Luna 1111fffd3a update lockfile 2022-08-13 17:00:43 -03:00
Luna 768611cc4e bump dependencies
use json provider interface for quart
2022-08-13 16:55:13 -03:00
spiral e3f894330d
fix formatting 2021-11-10 00:44:54 -05:00
spiral 2ad6b29175
add manage command set_max_concurrency 2021-11-10 00:29:07 -05:00
spiral bba48f7d0f
fix formatting 2021-09-24 13:40:03 -04:00
spiral b468883e2e
gateway: add max_concurrency support 2021-09-24 13:04:00 -04:00
20 changed files with 430 additions and 717 deletions

View File

@ -1,4 +1,4 @@
image: python:3.9-alpine
image: python:3.10-alpine
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"

View File

@ -29,7 +29,8 @@ from litecord.errors import MessageNotFound, Forbidden
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
from litecord.schemas import validate, MESSAGE_CREATE
from litecord.utils import pg_set_json, query_tuple_from_args, extract_limit
from litecord.utils import query_tuple_from_args, extract_limit
from litecord.json import pg_set_json
from litecord.permissions import get_permissions
from litecord.embed.sanitizer import fill_embed

View File

@ -51,6 +51,14 @@ async def api_gateway_bot():
user_id,
)
max_concurrency = await app.db.fetchval(
"""select max_concurrency
from users
where id = $1
""",
user_id,
)
shards = max(int(guild_count / 1000), 1)
# get _ws.session ratelimit
@ -78,7 +86,7 @@ async def api_gateway_bot():
"total": bucket.requests,
"remaining": bucket._tokens,
"reset_after": int(reset_after_ts * 1000),
"max_concurrency": 1,
"max_concurrency": max_concurrency,
},
}
)

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from quart import Blueprint, current_app as app, render_template_string
from quart import Blueprint, current_app as app, render_template_string, send_file
from pathlib import Path
bp = Blueprint("static", __name__)
@ -30,7 +30,10 @@ async def static_pages(path):
return "no", 404
static_path = Path.cwd() / Path("static") / path
return await app.send_static_file(str(static_path))
if static_path.exists():
return await send_file(static_path)
else:
return "not found", 404
@bp.route("/")

View File

@ -52,7 +52,7 @@ from litecord.common.messages import (
from litecord.embed.sanitizer import fill_embed, fetch_mediaproxy_img
from litecord.embed.messages import process_url_embed, is_media_url
from litecord.embed.schemas import EmbedURL
from litecord.utils import pg_set_json
from litecord.json import pg_set_json
from litecord.enums import MessageType
from litecord.images import STATIC_IMAGE_MIMES

View File

@ -20,7 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import json
import earl
from litecord.utils import LitecordJSONEncoder
from litecord.json import LitecordJSONEncoder
def encode_json(payload) -> str:

View File

@ -22,7 +22,7 @@ import asyncio
import pprint
import zlib
import time
from typing import List, Dict, Any, Iterable, Optional
from typing import List, Dict, Any, Iterable, Optional, Union
from random import randint
import websockets
@ -66,6 +66,7 @@ from litecord.gateway.schemas import (
)
from litecord.storage import int_
from litecord.blueprints.gateway import get_gw
log = Logger(__name__)
@ -353,8 +354,8 @@ class GatewayWebsocket:
"""Send a packet but just the OP code information is filled in."""
await self.send({"op": op_code, "d": data, "t": None, "s": None})
def _check_ratelimit(self, key: str, ratelimit_key):
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
def _check_ratelimit(self, key: str, ratelimit_key: Any):
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}", exact=True)
bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit()
@ -396,6 +397,12 @@ class GatewayWebsocket:
such as READY and RESUMED, or events that are replies to
messages in the websocket.
"""
if not self.state:
log.warning(
"can not dispatch {!r} as there is no state in ws {!r}", event, self
)
return
payload = {
"op": OP.DISPATCH,
"t": event.upper(),
@ -523,6 +530,7 @@ class GatewayWebsocket:
"session_id": self.state.session_id,
"_trace": ["transbian"],
"shard": [self.state.current_shard, self.state.shard_count],
"resume_gateway_url": get_gw(),
}
# base_ready and user_ready are normalized as v6. from here onwards
@ -758,7 +766,8 @@ class GatewayWebsocket:
await self.send_op(OP.HEARTBEAT_ACK, None)
async def _connect_ratelimit(self, user_id: int):
async def _connect_ratelimit(self, user_id: Union[int, str]):
log.debug("validating connect ratelimit against {!r}", user_id)
if self._check_ratelimit("connect", user_id):
await self.invalidate_session(False)
raise WebsocketClose(4009, "You are being ratelimited.")
@ -791,7 +800,15 @@ class GatewayWebsocket:
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, "Authentication failed")
await self._connect_ratelimit(user_id)
max_concurrency = await self.app.db.fetchval(
"""select max_concurrency
from users
where id = $1
""",
user_id,
)
await self._connect_ratelimit(f"{str(user_id)}%{shard[0]%max_concurrency}")
bot = await self.app.db.fetchval(
"""

69
litecord/json.py Normal file
View File

@ -0,0 +1,69 @@
"""
Litecord
Copyright (C) 2018-2021 Luna Mendes and Litecord Contributors
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import json
from typing import Any
from decimal import Decimal
from uuid import UUID
from dataclasses import asdict, is_dataclass
import quart.json.provider
class LitecordJSONEncoder(json.JSONEncoder):
"""Custom JSON encoder for Litecord. Useful for json.dumps"""
def default(self, value: Any):
if isinstance(value, (Decimal, UUID)):
return str(value)
if is_dataclass(value):
return asdict(value)
if hasattr(value, "to_json"):
return value.to_json
return super().default(self, value)
class LitecordJSONProvider(quart.json.provider.DefaultJSONProvider):
"""Custom JSON provider for Quart."""
def __init__(self, *args, **kwargs):
self.encoder = LitecordJSONEncoder(**kwargs)
def default(self, value: Any):
self.encoder.default(value)
async def pg_set_json(con):
"""Set JSON and JSONB codecs for an asyncpg connection."""
await con.set_type_codec(
"json",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema="pg_catalog",
)
await con.set_type_codec(
"jsonb",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema="pg_catalog",
)

View File

@ -97,6 +97,8 @@ async def ratelimit_handler():
request.discord_api_version = 8
elif rule.rule.startswith("/api/v9"):
request.discord_api_version = 9
elif rule.rule.startswith("/api/v10"):
request.discord_api_version = 10
else:
# default v6 lol
request.discord_api_version = 6

View File

@ -75,6 +75,13 @@ class RatelimitManager:
self._ratelimiters[path] = rtl
def get_ratelimit(self, key: str) -> Ratelimit:
def get_ratelimit(self, key: str, exact=False) -> Ratelimit:
"""Get the :class:`Ratelimit` instance for a given path."""
return self._ratelimiters.get(key, self.global_bucket)
bucket = self._ratelimiters.get(key)
if bucket:
return bucket
if not exact:
return self.global_bucket
else:
raise AssertionError(f"unknown ratelimit bucket '{key}'")

View File

@ -33,7 +33,7 @@ from litecord.blueprints.channel.reactions import (
from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
from litecord.types import timestamp_
from litecord.utils import pg_set_json
from litecord.json import pg_set_json
log = Logger(__name__)
@ -432,7 +432,7 @@ class Storage:
return {**row, **drow}
elif chan_type == ChannelType.GUILD_VOICE:
vrow = await self.db.fetchrow(
voice_channel_data = await self.db.fetchrow(
"""
SELECT bitrate, user_limit
FROM guild_voice_channels
@ -441,7 +441,19 @@ class Storage:
row["id"],
)
return {**row, **dict(vrow)}
guild_region = await self.db.fetchval(
"""
SELECT region
FROM guilds
WHERE guild.id = $1
""",
int(row["guild_id"]),
)
# see https://gitlab.com/litecord/litecord/-/issues/130
voice_channel_data["rtc_region"] = guild_region
return {**row, **dict(voice_channel_data)}
else:
# this only exists to trick mypy. this codepath is unreachable
raise AssertionError("Unreachable code path.")

View File

@ -18,14 +18,12 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import asyncio
import json
import secrets
import datetime
import re
from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
from logbook import Logger
from quart.json import JSONEncoder
from quart import current_app as app
from litecord.common.messages import message_view
@ -156,35 +154,6 @@ def mmh3(inp_str: str, seed: int = 0):
return _u(h1) >> 0
class LitecordJSONEncoder(JSONEncoder):
"""Custom JSON encoder for Litecord."""
def default(self, value: Any):
"""By default, this will try to get the to_json attribute of a given
value being JSON encoded."""
try:
return value.to_json
except AttributeError:
return super().default(value)
async def pg_set_json(con):
"""Set JSON and JSONB codecs for an asyncpg connection."""
await con.set_type_codec(
"json",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema="pg_catalog",
)
await con.set_type_codec(
"jsonb",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema="pg_catalog",
)
def yield_chunks(input_list: Sequence[Any], chunk_size: int):
"""Yield successive n-sized chunks from l.

View File

@ -0,0 +1,3 @@
alter table users
add column max_concurrency int not null default 1
check(bot = true or max_concurrency = 1);

View File

@ -94,6 +94,42 @@ async def adduser(ctx, args):
print(f'\tdiscrim: {user["discriminator"]}')
async def set_max_concurrency(ctx, args):
"""Update the `max_concurrency` for a bot.
This can only be set for bot accounts!
"""
if int(args.max_concurrency) < 1:
return print("max_concurrency must be >0")
bot = await ctx.db.fetchval(
"""
select bot
from users
where id = $1
""",
int(args.user_id),
)
if bot is None:
return print("user not found")
if not bot:
return print("user must be a bot")
await ctx.db.execute(
"""
update users
set max_concurrency = $1
where id = $2
""",
int(args.max_concurrency),
int(args.user_id),
)
print(f"OK: set max_concurrency={args.max_concurrency} for {args.user_id}")
async def addbot(ctx, args):
uid, _ = await create_user(args.username, args.email, args.password)
@ -216,6 +252,17 @@ def setup(subparser):
setup_test_parser.set_defaults(func=adduser)
set_max_concurrency_parser = subparser.add_parser(
"set_max_concurrency",
help="set `max_concurrency` for a user",
description=set_max_concurrency.__doc__,
)
set_max_concurrency_parser.add_argument("user_id")
set_max_concurrency_parser.add_argument(
"max_concurrency", help="the `max_concurrency` value to set"
)
set_max_concurrency_parser.set_defaults(func=set_max_concurrency)
addbot_parser = subparser.add_parser("addbot", help="create a bot")
addbot_parser.add_argument("username", help="username of the bot")

802
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -7,19 +7,19 @@ license = "GPLv3-only"
[tool.poetry.dependencies]
python = "^3.9"
bcrypt = "^3.2.0"
itsdangerous = "^1.1.0"
asyncpg = "^0.24.0"
websockets = "^10.0"
bcrypt = "^3.2.2"
itsdangerous = "^2.1.2"
asyncpg = "^0.26.0"
websockets = "^10.3"
Earl-ETF = "^2.1.2"
logbook = "^1.5.3"
Cerberus = "^1.3.4"
quart = {git = "https://gitlab.com/pgjones/quart", rev = "c1ac142c6c51709765045f830b242950099b2295"}
pillow = "^8.3.2"
aiohttp = "^3.7.4"
zstandard = "^0.15.2"
quart = "^0.18.0"
pillow = "^9.2.0"
aiohttp = "^3.8.1"
zstandard = "^0.18.0"
winter = {git = "https://gitlab.com/elixire/winter"}
wsproto = "^1.0.0"
wsproto = "^1.1.0"

8
run.py
View File

@ -105,7 +105,7 @@ from litecord.pubsub.lazy_guild import LazyGuildManager
from litecord.gateway.gateway import websocket_handler
from litecord.utils import LitecordJSONEncoder
from litecord.json import LitecordJSONProvider
# == HACKY PATCH ==
# this MUST be removed once Hypercorn gets py3.10 support.
@ -135,12 +135,12 @@ def make_app():
logging.getLogger("websockets").setLevel(logbook.INFO)
# use our custom json encoder for custom data types
app.json_encoder = LitecordJSONEncoder
app.json_provider_class = LitecordJSONProvider
return app
PREFIXES = ("/api/v6", "/api/v7", "/api/v8", "/api/v9")
PREFIXES = ("/api/v6", "/api/v7", "/api/v8", "/api/v9", "/api/v10")
def set_blueprints(app_):
@ -415,7 +415,7 @@ async def app_after_serving():
# first close all clients, then close db
tasks = app.state_manager.gen_close_tasks()
if tasks:
await asyncio.wait(tasks, loop=app.loop)
await asyncio.gather(*tasks)
app.state_manager.close()

View File

@ -78,3 +78,27 @@ async def test_guild_create(test_cli_user):
resp = await test_cli_user.delete(f"/api/v6/guilds/{guild_id}")
assert resp.status_code == 204
@pytest.mark.asyncio
async def test_guild_nickname(test_cli_user):
guild = await test_cli_user.create_guild()
NEW_NICKNAME = "my awesome nickname"
# stage 1: create
resp = await test_cli_user.patch(
f"/api/v6/guilds/{guild.id}/members/@me/nick",
json={"nick": NEW_NICKNAME},
)
assert resp.status_code == 200
assert (await resp.data).decode() == NEW_NICKNAME
# stage 2: test
resp = await test_cli_user.get(f"/api/v6/guilds/{guild.id}")
assert resp.status_code == 200
fetched_guild = await resp.json
assert fetched_guild["id"] == str(guild.id)
assert fetched_guild["members"][0]["nick"] == NEW_NICKNAME

View File

@ -39,6 +39,7 @@ from wsproto.events import (
from litecord.gateway.opcodes import OP
from litecord.gateway.websocket import decode_etf
from litecord.ratelimits.bucket import Ratelimit
# Z_SYNC_FLUSH suffix
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
@ -216,6 +217,7 @@ async def extract_and_verify_ready(conn, **kwargs):
assert isinstance(data["guilds"], list)
assert isinstance(data["session_id"], str)
assert isinstance(data["_trace"], list)
assert isinstance(data["resume_gateway_url"], str)
if "shard" in data:
assert isinstance(data["shard"], list)
@ -487,3 +489,46 @@ async def test_ready_bot_zlib_stream(test_cli_bot):
await extract_and_verify_ready(conn, zlib_stream=True)
finally:
await _close(conn)
@pytest.mark.asyncio
async def test_max_concurrency(test_cli_bot):
session_ratelimiter = test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"]
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = Ratelimit(1, 5)
try:
gateway_start_coroutines = [gw_start(test_cli_bot.cli) for _ in range(5)]
connections = await asyncio.gather(*gateway_start_coroutines)
# read all HELLOs we send by default
await asyncio.gather(*[_json(conn) for conn in connections])
# make everyone IDENTIFY
await asyncio.gather(
*[
_json_send(
conn,
{"op": OP.IDENTIFY, "d": {"token": test_cli_bot.user["token"]}},
)
for conn in connections
]
)
await asyncio.sleep(0.5)
# only one of them gets a READY
success_count, error_count = 0, 0
for conn in connections:
try:
data = await _json(conn)
if data["op"] == OP.DISPATCH:
success_count += 1
else:
error_count += 1
finally:
await _close(conn)
assert success_count == 1
assert error_count == 4
finally:
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = session_ratelimiter

16
tox.ini
View File

@ -1,22 +1,22 @@
[tox]
envlist = py3.9
envlist = py3.10
isolated_build = true
[testenv]
ignore_errors = true
deps =
pytest==6.2.5
pytest-asyncio==0.15.1
pytest-cov==2.12.1
flake8==3.9.2
black==21.6b0
mypy==0.910
pytest==7.1.2
pytest-asyncio==0.19.0
pytest-cov==3.0.0
flake8==5.0.4
black==22.6.0
mypy==0.971
pytest-instafail==0.4.2
commands =
python3 ./manage.py migrate
black --check litecord run.py tests manage
flake8 litecord run.py tests manage
pytest {posargs:tests}
pytest --asyncio-mode=auto {posargs:tests}
[flake8]
max-line-length = 88