diff --git a/litecord/blueprints/gateway.py b/litecord/blueprints/gateway.py index 5453cef..e6e0d2f 100644 --- a/litecord/blueprints/gateway.py +++ b/litecord/blueprints/gateway.py @@ -51,6 +51,14 @@ async def api_gateway_bot(): user_id, ) + max_concurrency = await app.db.fetchval( + """select max_concurrency + from users + where id = $1 + """, + user_id, + ) + shards = max(int(guild_count / 1000), 1) # get _ws.session ratelimit @@ -78,7 +86,7 @@ async def api_gateway_bot(): "total": bucket.requests, "remaining": bucket._tokens, "reset_after": int(reset_after_ts * 1000), - "max_concurrency": 1, + "max_concurrency": max_concurrency, }, } ) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 5e86c99..3f12d06 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -791,7 +791,15 @@ class GatewayWebsocket: except (Unauthorized, Forbidden): raise WebsocketClose(4004, "Authentication failed") - await self._connect_ratelimit(user_id) + max_concurrency = await self.app.db.fetchval( + """select max_concurrency + from users + where id = $1 + """, + user_id, + ) + + await self._connect_ratelimit(f"{str(user_id)}%{shard[0]%max_concurrency}") bot = await self.app.db.fetchval( """ diff --git a/manage/cmd/migration/scripts/14_max_concurrency.sql b/manage/cmd/migration/scripts/14_max_concurrency.sql new file mode 100644 index 0000000..3366bc9 --- /dev/null +++ b/manage/cmd/migration/scripts/14_max_concurrency.sql @@ -0,0 +1,3 @@ +alter table users + add column max_concurrency int not null default 1 + check(bot = true or max_concurrency = 1); \ No newline at end of file diff --git a/manage/cmd/users.py b/manage/cmd/users.py index d1ac694..87d3fdd 100644 --- a/manage/cmd/users.py +++ b/manage/cmd/users.py @@ -94,6 +94,42 @@ async def adduser(ctx, args): print(f'\tdiscrim: {user["discriminator"]}') +async def set_max_concurrency(ctx, args): + """Update the `max_concurrency` for a bot. + This can only be set for bot accounts! + """ + + if int(args.max_concurrency) < 1: + return print("max_concurrency must be >0") + + bot = await ctx.db.fetchval( + """ + select bot + from users + where id = $1 + """, + int(args.user_id), + ) + + if bot == None: + return print("user not found") + + if bot == False: + return print("user must be a bot") + + await ctx.db.execute( + """ + update users + set max_concurrency = $1 + where id = $2 + """, + int(args.max_concurrency), + int(args.user_id), + ) + + print(f"OK: set max_concurrency={args.max_concurrency} for {args.user_id}") + + async def addbot(ctx, args): uid, _ = await create_user(args.username, args.email, args.password) @@ -216,6 +252,17 @@ def setup(subparser): setup_test_parser.set_defaults(func=adduser) + set_max_concurrency_parser = subparser.add_parser( + "set_max_concurrency", + help="set `max_concurrency` for a user", + description=set_max_concurrency.__doc__, + ) + set_max_concurrency_parser.add_argument("user_id") + set_max_concurrency_parser.add_argument( + "max_concurrency", help="the `max_concurrency` value to set" + ) + set_max_concurrency_parser.set_defaults(func=set_max_concurrency) + addbot_parser = subparser.add_parser("addbot", help="create a bot") addbot_parser.add_argument("username", help="username of the bot")