diff --git a/litecord/blueprints/gateway.py b/litecord/blueprints/gateway.py index bb90d42..95ce132 100644 --- a/litecord/blueprints/gateway.py +++ b/litecord/blueprints/gateway.py @@ -51,6 +51,14 @@ async def api_gateway_bot(): user_id, ) + max_concurrency = await app.db.fetchval( + """select max_concurrency + from users + where id = $1 + """, + user_id, + ) + shards = max(int(guild_count / 1000), 1) # get _ws.session ratelimit @@ -78,7 +86,7 @@ async def api_gateway_bot(): "total": bucket.requests, "remaining": bucket._tokens, "reset_after": int(reset_after_ts * 1000), - "max_concurrency": 1, + "max_concurrency": max_concurrency, }, } ) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index c306591..3190024 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -811,7 +811,15 @@ class GatewayWebsocket: except (Unauthorized, Forbidden): raise WebsocketClose(4004, "Authentication failed") - await self._connect_ratelimit(user_id) + max_concurrency = await self.app.db.fetchval( + """select max_concurrency + from users + where id = $1 + """, + user_id, + ) + + await self._connect_ratelimit(f"{str(user_id)}%{shard[0]%max_concurrency}") bot = await self.app.db.fetchval( """ diff --git a/manage/cmd/migration/scripts/14_max_concurrency.sql b/manage/cmd/migration/scripts/14_max_concurrency.sql new file mode 100644 index 0000000..3366bc9 --- /dev/null +++ b/manage/cmd/migration/scripts/14_max_concurrency.sql @@ -0,0 +1,3 @@ +alter table users + add column max_concurrency int not null default 1 + check(bot = true or max_concurrency = 1); \ No newline at end of file