Compare commits

..

No commits in common. "1d4f99f375e031053f1f03c08da9d47b2301953c" and "3c815cf872505da198ba4a4ad9857858a75f42f7" have entirely different histories.

20 changed files with 716 additions and 429 deletions

View File

@ -1,4 +1,4 @@
image: python:3.10-alpine
image: python:3.9-alpine
variables:
PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip"

View File

@ -29,8 +29,7 @@ from litecord.errors import MessageNotFound, Forbidden
from litecord.enums import MessageType, ChannelType, GUILD_CHANS
from litecord.schemas import validate, MESSAGE_CREATE
from litecord.utils import query_tuple_from_args, extract_limit
from litecord.json import pg_set_json
from litecord.utils import pg_set_json, query_tuple_from_args, extract_limit
from litecord.permissions import get_permissions
from litecord.embed.sanitizer import fill_embed

View File

@ -51,14 +51,6 @@ async def api_gateway_bot():
user_id,
)
max_concurrency = await app.db.fetchval(
"""select max_concurrency
from users
where id = $1
""",
user_id,
)
shards = max(int(guild_count / 1000), 1)
# get _ws.session ratelimit
@ -86,7 +78,7 @@ async def api_gateway_bot():
"total": bucket.requests,
"remaining": bucket._tokens,
"reset_after": int(reset_after_ts * 1000),
"max_concurrency": max_concurrency,
"max_concurrency": 1,
},
}
)

View File

@ -17,7 +17,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
from quart import Blueprint, current_app as app, render_template_string, send_file
from quart import Blueprint, current_app as app, render_template_string
from pathlib import Path
bp = Blueprint("static", __name__)
@ -30,10 +30,7 @@ async def static_pages(path):
return "no", 404
static_path = Path.cwd() / Path("static") / path
if static_path.exists():
return await send_file(static_path)
else:
return "not found", 404
return await app.send_static_file(str(static_path))
@bp.route("/")

View File

@ -52,7 +52,7 @@ from litecord.common.messages import (
from litecord.embed.sanitizer import fill_embed, fetch_mediaproxy_img
from litecord.embed.messages import process_url_embed, is_media_url
from litecord.embed.schemas import EmbedURL
from litecord.json import pg_set_json
from litecord.utils import pg_set_json
from litecord.enums import MessageType
from litecord.images import STATIC_IMAGE_MIMES

View File

@ -20,7 +20,7 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
import json
import earl
from litecord.json import LitecordJSONEncoder
from litecord.utils import LitecordJSONEncoder
def encode_json(payload) -> str:

View File

@ -22,7 +22,7 @@ import asyncio
import pprint
import zlib
import time
from typing import List, Dict, Any, Iterable, Optional, Union
from typing import List, Dict, Any, Iterable, Optional
from random import randint
import websockets
@ -66,7 +66,6 @@ from litecord.gateway.schemas import (
)
from litecord.storage import int_
from litecord.blueprints.gateway import get_gw
log = Logger(__name__)
@ -354,8 +353,8 @@ class GatewayWebsocket:
"""Send a packet but just the OP code information is filled in."""
await self.send({"op": op_code, "d": data, "t": None, "s": None})
def _check_ratelimit(self, key: str, ratelimit_key: Any):
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}", exact=True)
def _check_ratelimit(self, key: str, ratelimit_key):
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit()
@ -397,12 +396,6 @@ class GatewayWebsocket:
such as READY and RESUMED, or events that are replies to
messages in the websocket.
"""
if not self.state:
log.warning(
"can not dispatch {!r} as there is no state in ws {!r}", event, self
)
return
payload = {
"op": OP.DISPATCH,
"t": event.upper(),
@ -530,7 +523,6 @@ class GatewayWebsocket:
"session_id": self.state.session_id,
"_trace": ["transbian"],
"shard": [self.state.current_shard, self.state.shard_count],
"resume_gateway_url": get_gw(),
}
# base_ready and user_ready are normalized as v6. from here onwards
@ -766,8 +758,7 @@ class GatewayWebsocket:
await self.send_op(OP.HEARTBEAT_ACK, None)
async def _connect_ratelimit(self, user_id: Union[int, str]):
log.debug("validating connect ratelimit against {!r}", user_id)
async def _connect_ratelimit(self, user_id: int):
if self._check_ratelimit("connect", user_id):
await self.invalidate_session(False)
raise WebsocketClose(4009, "You are being ratelimited.")
@ -800,15 +791,7 @@ class GatewayWebsocket:
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, "Authentication failed")
max_concurrency = await self.app.db.fetchval(
"""select max_concurrency
from users
where id = $1
""",
user_id,
)
await self._connect_ratelimit(f"{str(user_id)}%{shard[0]%max_concurrency}")
await self._connect_ratelimit(user_id)
bot = await self.app.db.fetchval(
"""

View File

@ -1,69 +0,0 @@
"""
Litecord
Copyright (C) 2018-2021 Luna Mendes and Litecord Contributors
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, version 3 of the License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import json
from typing import Any
from decimal import Decimal
from uuid import UUID
from dataclasses import asdict, is_dataclass
import quart.json.provider
class LitecordJSONEncoder(json.JSONEncoder):
"""Custom JSON encoder for Litecord. Useful for json.dumps"""
def default(self, value: Any):
if isinstance(value, (Decimal, UUID)):
return str(value)
if is_dataclass(value):
return asdict(value)
if hasattr(value, "to_json"):
return value.to_json
return super().default(self, value)
class LitecordJSONProvider(quart.json.provider.DefaultJSONProvider):
"""Custom JSON provider for Quart."""
def __init__(self, *args, **kwargs):
self.encoder = LitecordJSONEncoder(**kwargs)
def default(self, value: Any):
self.encoder.default(value)
async def pg_set_json(con):
"""Set JSON and JSONB codecs for an asyncpg connection."""
await con.set_type_codec(
"json",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema="pg_catalog",
)
await con.set_type_codec(
"jsonb",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema="pg_catalog",
)

View File

@ -97,8 +97,6 @@ async def ratelimit_handler():
request.discord_api_version = 8
elif rule.rule.startswith("/api/v9"):
request.discord_api_version = 9
elif rule.rule.startswith("/api/v10"):
request.discord_api_version = 10
else:
# default v6 lol
request.discord_api_version = 6

View File

@ -75,13 +75,6 @@ class RatelimitManager:
self._ratelimiters[path] = rtl
def get_ratelimit(self, key: str, exact=False) -> Ratelimit:
def get_ratelimit(self, key: str) -> Ratelimit:
"""Get the :class:`Ratelimit` instance for a given path."""
bucket = self._ratelimiters.get(key)
if bucket:
return bucket
if not exact:
return self.global_bucket
else:
raise AssertionError(f"unknown ratelimit bucket '{key}'")
return self._ratelimiters.get(key, self.global_bucket)

View File

@ -33,7 +33,7 @@ from litecord.blueprints.channel.reactions import (
from litecord.blueprints.user.billing import PLAN_ID_TO_TYPE
from litecord.types import timestamp_
from litecord.json import pg_set_json
from litecord.utils import pg_set_json
log = Logger(__name__)
@ -432,7 +432,7 @@ class Storage:
return {**row, **drow}
elif chan_type == ChannelType.GUILD_VOICE:
voice_channel_data = await self.db.fetchrow(
vrow = await self.db.fetchrow(
"""
SELECT bitrate, user_limit
FROM guild_voice_channels
@ -441,19 +441,7 @@ class Storage:
row["id"],
)
guild_region = await self.db.fetchval(
"""
SELECT region
FROM guilds
WHERE guild.id = $1
""",
int(row["guild_id"]),
)
# see https://gitlab.com/litecord/litecord/-/issues/130
voice_channel_data["rtc_region"] = guild_region
return {**row, **dict(voice_channel_data)}
return {**row, **dict(vrow)}
else:
# this only exists to trick mypy. this codepath is unreachable
raise AssertionError("Unreachable code path.")

View File

@ -18,12 +18,14 @@ along with this program. If not, see <http://www.gnu.org/licenses/>.
"""
import asyncio
import json
import secrets
import datetime
import re
from typing import Any, Iterable, Optional, Sequence, List, Dict, Union
from logbook import Logger
from quart.json import JSONEncoder
from quart import current_app as app
from litecord.common.messages import message_view
@ -154,6 +156,35 @@ def mmh3(inp_str: str, seed: int = 0):
return _u(h1) >> 0
class LitecordJSONEncoder(JSONEncoder):
"""Custom JSON encoder for Litecord."""
def default(self, value: Any):
"""By default, this will try to get the to_json attribute of a given
value being JSON encoded."""
try:
return value.to_json
except AttributeError:
return super().default(value)
async def pg_set_json(con):
"""Set JSON and JSONB codecs for an asyncpg connection."""
await con.set_type_codec(
"json",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema="pg_catalog",
)
await con.set_type_codec(
"jsonb",
encoder=lambda v: json.dumps(v, cls=LitecordJSONEncoder),
decoder=json.loads,
schema="pg_catalog",
)
def yield_chunks(input_list: Sequence[Any], chunk_size: int):
"""Yield successive n-sized chunks from l.

View File

@ -1,3 +0,0 @@
alter table users
add column max_concurrency int not null default 1
check(bot = true or max_concurrency = 1);

View File

@ -94,42 +94,6 @@ async def adduser(ctx, args):
print(f'\tdiscrim: {user["discriminator"]}')
async def set_max_concurrency(ctx, args):
"""Update the `max_concurrency` for a bot.
This can only be set for bot accounts!
"""
if int(args.max_concurrency) < 1:
return print("max_concurrency must be >0")
bot = await ctx.db.fetchval(
"""
select bot
from users
where id = $1
""",
int(args.user_id),
)
if bot is None:
return print("user not found")
if not bot:
return print("user must be a bot")
await ctx.db.execute(
"""
update users
set max_concurrency = $1
where id = $2
""",
int(args.max_concurrency),
int(args.user_id),
)
print(f"OK: set max_concurrency={args.max_concurrency} for {args.user_id}")
async def addbot(ctx, args):
uid, _ = await create_user(args.username, args.email, args.password)
@ -252,17 +216,6 @@ def setup(subparser):
setup_test_parser.set_defaults(func=adduser)
set_max_concurrency_parser = subparser.add_parser(
"set_max_concurrency",
help="set `max_concurrency` for a user",
description=set_max_concurrency.__doc__,
)
set_max_concurrency_parser.add_argument("user_id")
set_max_concurrency_parser.add_argument(
"max_concurrency", help="the `max_concurrency` value to set"
)
set_max_concurrency_parser.set_defaults(func=set_max_concurrency)
addbot_parser = subparser.add_parser("addbot", help="create a bot")
addbot_parser.add_argument("username", help="username of the bot")

800
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -7,19 +7,19 @@ license = "GPLv3-only"
[tool.poetry.dependencies]
python = "^3.9"
bcrypt = "^3.2.2"
itsdangerous = "^2.1.2"
asyncpg = "^0.26.0"
websockets = "^10.3"
bcrypt = "^3.2.0"
itsdangerous = "^1.1.0"
asyncpg = "^0.24.0"
websockets = "^10.0"
Earl-ETF = "^2.1.2"
logbook = "^1.5.3"
Cerberus = "^1.3.4"
quart = "^0.18.0"
pillow = "^9.2.0"
aiohttp = "^3.8.1"
zstandard = "^0.18.0"
quart = {git = "https://gitlab.com/pgjones/quart", rev = "c1ac142c6c51709765045f830b242950099b2295"}
pillow = "^8.3.2"
aiohttp = "^3.7.4"
zstandard = "^0.15.2"
winter = {git = "https://gitlab.com/elixire/winter"}
wsproto = "^1.1.0"
wsproto = "^1.0.0"

8
run.py
View File

@ -105,7 +105,7 @@ from litecord.pubsub.lazy_guild import LazyGuildManager
from litecord.gateway.gateway import websocket_handler
from litecord.json import LitecordJSONProvider
from litecord.utils import LitecordJSONEncoder
# == HACKY PATCH ==
# this MUST be removed once Hypercorn gets py3.10 support.
@ -135,12 +135,12 @@ def make_app():
logging.getLogger("websockets").setLevel(logbook.INFO)
# use our custom json encoder for custom data types
app.json_provider_class = LitecordJSONProvider
app.json_encoder = LitecordJSONEncoder
return app
PREFIXES = ("/api/v6", "/api/v7", "/api/v8", "/api/v9", "/api/v10")
PREFIXES = ("/api/v6", "/api/v7", "/api/v8", "/api/v9")
def set_blueprints(app_):
@ -415,7 +415,7 @@ async def app_after_serving():
# first close all clients, then close db
tasks = app.state_manager.gen_close_tasks()
if tasks:
await asyncio.gather(*tasks)
await asyncio.wait(tasks, loop=app.loop)
app.state_manager.close()

View File

@ -78,27 +78,3 @@ async def test_guild_create(test_cli_user):
resp = await test_cli_user.delete(f"/api/v6/guilds/{guild_id}")
assert resp.status_code == 204
@pytest.mark.asyncio
async def test_guild_nickname(test_cli_user):
guild = await test_cli_user.create_guild()
NEW_NICKNAME = "my awesome nickname"
# stage 1: create
resp = await test_cli_user.patch(
f"/api/v6/guilds/{guild.id}/members/@me/nick",
json={"nick": NEW_NICKNAME},
)
assert resp.status_code == 200
assert (await resp.data).decode() == NEW_NICKNAME
# stage 2: test
resp = await test_cli_user.get(f"/api/v6/guilds/{guild.id}")
assert resp.status_code == 200
fetched_guild = await resp.json
assert fetched_guild["id"] == str(guild.id)
assert fetched_guild["members"][0]["nick"] == NEW_NICKNAME

View File

@ -39,7 +39,6 @@ from wsproto.events import (
from litecord.gateway.opcodes import OP
from litecord.gateway.websocket import decode_etf
from litecord.ratelimits.bucket import Ratelimit
# Z_SYNC_FLUSH suffix
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
@ -217,7 +216,6 @@ async def extract_and_verify_ready(conn, **kwargs):
assert isinstance(data["guilds"], list)
assert isinstance(data["session_id"], str)
assert isinstance(data["_trace"], list)
assert isinstance(data["resume_gateway_url"], str)
if "shard" in data:
assert isinstance(data["shard"], list)
@ -489,46 +487,3 @@ async def test_ready_bot_zlib_stream(test_cli_bot):
await extract_and_verify_ready(conn, zlib_stream=True)
finally:
await _close(conn)
@pytest.mark.asyncio
async def test_max_concurrency(test_cli_bot):
session_ratelimiter = test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"]
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = Ratelimit(1, 5)
try:
gateway_start_coroutines = [gw_start(test_cli_bot.cli) for _ in range(5)]
connections = await asyncio.gather(*gateway_start_coroutines)
# read all HELLOs we send by default
await asyncio.gather(*[_json(conn) for conn in connections])
# make everyone IDENTIFY
await asyncio.gather(
*[
_json_send(
conn,
{"op": OP.IDENTIFY, "d": {"token": test_cli_bot.user["token"]}},
)
for conn in connections
]
)
await asyncio.sleep(0.5)
# only one of them gets a READY
success_count, error_count = 0, 0
for conn in connections:
try:
data = await _json(conn)
if data["op"] == OP.DISPATCH:
success_count += 1
else:
error_count += 1
finally:
await _close(conn)
assert success_count == 1
assert error_count == 4
finally:
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = session_ratelimiter

16
tox.ini
View File

@ -1,22 +1,22 @@
[tox]
envlist = py3.10
envlist = py3.9
isolated_build = true
[testenv]
ignore_errors = true
deps =
pytest==7.1.2
pytest-asyncio==0.19.0
pytest-cov==3.0.0
flake8==5.0.4
black==22.6.0
mypy==0.971
pytest==6.2.5
pytest-asyncio==0.15.1
pytest-cov==2.12.1
flake8==3.9.2
black==21.6b0
mypy==0.910
pytest-instafail==0.4.2
commands =
python3 ./manage.py migrate
black --check litecord run.py tests manage
flake8 litecord run.py tests manage
pytest --asyncio-mode=auto {posargs:tests}
pytest {posargs:tests}
[flake8]
max-line-length = 88