Merge branch 'feat/max_concurrency' into 'master'

gateway: add max_concurrency support

See merge request litecord/litecord!82
This commit is contained in:
spiral 2022-08-13 21:49:24 +00:00
commit 39bc9e4a14
4 changed files with 68 additions and 2 deletions

View File

@ -51,6 +51,14 @@ async def api_gateway_bot():
user_id, user_id,
) )
max_concurrency = await app.db.fetchval(
"""select max_concurrency
from users
where id = $1
""",
user_id,
)
shards = max(int(guild_count / 1000), 1) shards = max(int(guild_count / 1000), 1)
# get _ws.session ratelimit # get _ws.session ratelimit
@ -78,7 +86,7 @@ async def api_gateway_bot():
"total": bucket.requests, "total": bucket.requests,
"remaining": bucket._tokens, "remaining": bucket._tokens,
"reset_after": int(reset_after_ts * 1000), "reset_after": int(reset_after_ts * 1000),
"max_concurrency": 1, "max_concurrency": max_concurrency,
}, },
} }
) )

View File

@ -791,7 +791,15 @@ class GatewayWebsocket:
except (Unauthorized, Forbidden): except (Unauthorized, Forbidden):
raise WebsocketClose(4004, "Authentication failed") raise WebsocketClose(4004, "Authentication failed")
await self._connect_ratelimit(user_id) max_concurrency = await self.app.db.fetchval(
"""select max_concurrency
from users
where id = $1
""",
user_id,
)
await self._connect_ratelimit(f"{str(user_id)}%{shard[0]%max_concurrency}")
bot = await self.app.db.fetchval( bot = await self.app.db.fetchval(
""" """

View File

@ -0,0 +1,3 @@
alter table users
add column max_concurrency int not null default 1
check(bot = true or max_concurrency = 1);

View File

@ -94,6 +94,42 @@ async def adduser(ctx, args):
print(f'\tdiscrim: {user["discriminator"]}') print(f'\tdiscrim: {user["discriminator"]}')
async def set_max_concurrency(ctx, args):
"""Update the `max_concurrency` for a bot.
This can only be set for bot accounts!
"""
if int(args.max_concurrency) < 1:
return print("max_concurrency must be >0")
bot = await ctx.db.fetchval(
"""
select bot
from users
where id = $1
""",
int(args.user_id),
)
if bot == None:
return print("user not found")
if bot == False:
return print("user must be a bot")
await ctx.db.execute(
"""
update users
set max_concurrency = $1
where id = $2
""",
int(args.max_concurrency),
int(args.user_id),
)
print(f"OK: set max_concurrency={args.max_concurrency} for {args.user_id}")
async def addbot(ctx, args): async def addbot(ctx, args):
uid, _ = await create_user(args.username, args.email, args.password) uid, _ = await create_user(args.username, args.email, args.password)
@ -216,6 +252,17 @@ def setup(subparser):
setup_test_parser.set_defaults(func=adduser) setup_test_parser.set_defaults(func=adduser)
set_max_concurrency_parser = subparser.add_parser(
"set_max_concurrency",
help="set `max_concurrency` for a user",
description=set_max_concurrency.__doc__,
)
set_max_concurrency_parser.add_argument("user_id")
set_max_concurrency_parser.add_argument(
"max_concurrency", help="the `max_concurrency` value to set"
)
set_max_concurrency_parser.set_defaults(func=set_max_concurrency)
addbot_parser = subparser.add_parser("addbot", help="create a bot") addbot_parser = subparser.add_parser("addbot", help="create a bot")
addbot_parser.add_argument("username", help="username of the bot") addbot_parser.add_argument("username", help="username of the bot")