mirror of https://gitlab.com/litecord/litecord.git
Merge branch 'feat/max_concurrency' into 'master'
gateway: add max_concurrency support See merge request litecord/litecord!82
This commit is contained in:
commit
d24df5a1ea
|
|
@ -51,6 +51,14 @@ async def api_gateway_bot():
|
|||
user_id,
|
||||
)
|
||||
|
||||
max_concurrency = await app.db.fetchval(
|
||||
"""select max_concurrency
|
||||
from users
|
||||
where id = $1
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
shards = max(int(guild_count / 1000), 1)
|
||||
|
||||
# get _ws.session ratelimit
|
||||
|
|
@ -78,7 +86,7 @@ async def api_gateway_bot():
|
|||
"total": bucket.requests,
|
||||
"remaining": bucket._tokens,
|
||||
"reset_after": int(reset_after_ts * 1000),
|
||||
"max_concurrency": 1,
|
||||
"max_concurrency": max_concurrency,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -791,7 +791,15 @@ class GatewayWebsocket:
|
|||
except (Unauthorized, Forbidden):
|
||||
raise WebsocketClose(4004, "Authentication failed")
|
||||
|
||||
await self._connect_ratelimit(user_id)
|
||||
max_concurrency = await self.app.db.fetchval(
|
||||
"""select max_concurrency
|
||||
from users
|
||||
where id = $1
|
||||
""",
|
||||
user_id,
|
||||
)
|
||||
|
||||
await self._connect_ratelimit(f"{str(user_id)}%{shard[0]%max_concurrency}")
|
||||
|
||||
bot = await self.app.db.fetchval(
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
alter table users
|
||||
add column max_concurrency int not null default 1
|
||||
check(bot = true or max_concurrency = 1);
|
||||
|
|
@ -94,6 +94,42 @@ async def adduser(ctx, args):
|
|||
print(f'\tdiscrim: {user["discriminator"]}')
|
||||
|
||||
|
||||
async def set_max_concurrency(ctx, args):
|
||||
"""Update the `max_concurrency` for a bot.
|
||||
This can only be set for bot accounts!
|
||||
"""
|
||||
|
||||
if int(args.max_concurrency) < 1:
|
||||
return print("max_concurrency must be >0")
|
||||
|
||||
bot = await ctx.db.fetchval(
|
||||
"""
|
||||
select bot
|
||||
from users
|
||||
where id = $1
|
||||
""",
|
||||
int(args.user_id),
|
||||
)
|
||||
|
||||
if bot == None:
|
||||
return print("user not found")
|
||||
|
||||
if bot == False:
|
||||
return print("user must be a bot")
|
||||
|
||||
await ctx.db.execute(
|
||||
"""
|
||||
update users
|
||||
set max_concurrency = $1
|
||||
where id = $2
|
||||
""",
|
||||
int(args.max_concurrency),
|
||||
int(args.user_id),
|
||||
)
|
||||
|
||||
print(f"OK: set max_concurrency={args.max_concurrency} for {args.user_id}")
|
||||
|
||||
|
||||
async def set_flag(ctx, args):
|
||||
"""Setting a 'staff' flag gives the user access to the Admin API.
|
||||
Beware of that.
|
||||
|
|
@ -198,6 +234,17 @@ def setup(subparser):
|
|||
|
||||
setup_test_parser.set_defaults(func=adduser)
|
||||
|
||||
set_max_concurrency_parser = subparser.add_parser(
|
||||
"set_max_concurrency",
|
||||
help="set `max_concurrency` for a user",
|
||||
description=set_max_concurrency.__doc__,
|
||||
)
|
||||
set_max_concurrency_parser.add_argument("user_id")
|
||||
set_max_concurrency_parser.add_argument(
|
||||
"max_concurrency", help="the `max_concurrency` value to set"
|
||||
)
|
||||
set_max_concurrency_parser.set_defaults(func=set_max_concurrency)
|
||||
|
||||
setflag_parser = subparser.add_parser(
|
||||
"setflag", help="set a flag for a user", description=set_flag.__doc__
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue