mirror of https://gitlab.com/litecord/litecord.git
add test for max_concurrency
see https://gitlab.com/litecord/litecord/-/merge_requests/82#note_728950904
This commit is contained in:
parent
992b2cbf1f
commit
6024eee19b
|
|
@ -22,7 +22,7 @@ import asyncio
|
|||
import pprint
|
||||
import zlib
|
||||
import time
|
||||
from typing import List, Dict, Any, Iterable, Optional
|
||||
from typing import List, Dict, Any, Iterable, Optional, Union
|
||||
from random import randint
|
||||
|
||||
import websockets
|
||||
|
|
@ -353,8 +353,8 @@ class GatewayWebsocket:
|
|||
"""Send a packet but just the OP code information is filled in."""
|
||||
await self.send({"op": op_code, "d": data, "t": None, "s": None})
|
||||
|
||||
def _check_ratelimit(self, key: str, ratelimit_key):
|
||||
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
|
||||
def _check_ratelimit(self, key: str, ratelimit_key: Any):
|
||||
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}", exact=True)
|
||||
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||
return bucket.update_rate_limit()
|
||||
|
||||
|
|
@ -758,7 +758,8 @@ class GatewayWebsocket:
|
|||
|
||||
await self.send_op(OP.HEARTBEAT_ACK, None)
|
||||
|
||||
async def _connect_ratelimit(self, user_id: int):
|
||||
async def _connect_ratelimit(self, user_id: Union[int, str]):
|
||||
log.debug("validating connect ratelimit against {!r}", user_id)
|
||||
if self._check_ratelimit("connect", user_id):
|
||||
await self.invalidate_session(False)
|
||||
raise WebsocketClose(4009, "You are being ratelimited.")
|
||||
|
|
|
|||
|
|
@ -75,6 +75,13 @@ class RatelimitManager:
|
|||
|
||||
self._ratelimiters[path] = rtl
|
||||
|
||||
def get_ratelimit(self, key: str) -> Ratelimit:
|
||||
def get_ratelimit(self, key: str, exact=False) -> Ratelimit:
|
||||
"""Get the :class:`Ratelimit` instance for a given path."""
|
||||
return self._ratelimiters.get(key, self.global_bucket)
|
||||
bucket = self._ratelimiters.get(key)
|
||||
if bucket:
|
||||
return bucket
|
||||
|
||||
if not exact:
|
||||
return self.global_bucket
|
||||
else:
|
||||
raise AssertionError(f"unknown ratelimit bucket '{key}'")
|
||||
|
|
|
|||
2
run.py
2
run.py
|
|
@ -415,7 +415,7 @@ async def app_after_serving():
|
|||
# first close all clients, then close db
|
||||
tasks = app.state_manager.gen_close_tasks()
|
||||
if tasks:
|
||||
await asyncio.wait(tasks, loop=app.loop)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
app.state_manager.close()
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ from wsproto.events import (
|
|||
|
||||
from litecord.gateway.opcodes import OP
|
||||
from litecord.gateway.websocket import decode_etf
|
||||
from litecord.ratelimits.bucket import Ratelimit
|
||||
|
||||
# Z_SYNC_FLUSH suffix
|
||||
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
|
||||
|
|
@ -487,3 +488,46 @@ async def test_ready_bot_zlib_stream(test_cli_bot):
|
|||
await extract_and_verify_ready(conn, zlib_stream=True)
|
||||
finally:
|
||||
await _close(conn)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_concurrency(test_cli_bot):
|
||||
session_ratelimiter = test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"]
|
||||
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = Ratelimit(1, 5)
|
||||
|
||||
try:
|
||||
|
||||
gateway_start_coroutines = [gw_start(test_cli_bot.cli) for _ in range(5)]
|
||||
connections = await asyncio.gather(*gateway_start_coroutines)
|
||||
|
||||
# read all HELLOs we send by default
|
||||
await asyncio.gather(*[_json(conn) for conn in connections])
|
||||
|
||||
# make everyone IDENTIFY
|
||||
await asyncio.gather(
|
||||
*[
|
||||
_json_send(
|
||||
conn,
|
||||
{"op": OP.IDENTIFY, "d": {"token": test_cli_bot.user["token"]}},
|
||||
)
|
||||
for conn in connections
|
||||
]
|
||||
)
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# only one of them gets a READY
|
||||
success_count, error_count = 0, 0
|
||||
for conn in connections:
|
||||
try:
|
||||
data = await _json(conn)
|
||||
if data["op"] == OP.DISPATCH:
|
||||
success_count += 1
|
||||
else:
|
||||
error_count += 1
|
||||
finally:
|
||||
await _close(conn)
|
||||
assert success_count == 1
|
||||
assert error_count == 4
|
||||
|
||||
finally:
|
||||
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = session_ratelimiter
|
||||
|
|
|
|||
Loading…
Reference in New Issue