add test for max_concurrency

see https://gitlab.com/litecord/litecord/-/merge_requests/82#note_728950904
This commit is contained in:
Luna 2022-08-13 19:35:51 -03:00
parent 992b2cbf1f
commit 6024eee19b
4 changed files with 59 additions and 7 deletions

View File

@ -22,7 +22,7 @@ import asyncio
import pprint
import zlib
import time
from typing import List, Dict, Any, Iterable, Optional
from typing import List, Dict, Any, Iterable, Optional, Union
from random import randint
import websockets
@ -353,8 +353,8 @@ class GatewayWebsocket:
"""Send a packet but just the OP code information is filled in."""
await self.send({"op": op_code, "d": data, "t": None, "s": None})
def _check_ratelimit(self, key: str, ratelimit_key):
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
def _check_ratelimit(self, key: str, ratelimit_key: Any):
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}", exact=True)
bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit()
@ -758,7 +758,8 @@ class GatewayWebsocket:
await self.send_op(OP.HEARTBEAT_ACK, None)
async def _connect_ratelimit(self, user_id: int):
async def _connect_ratelimit(self, user_id: Union[int, str]):
log.debug("validating connect ratelimit against {!r}", user_id)
if self._check_ratelimit("connect", user_id):
await self.invalidate_session(False)
raise WebsocketClose(4009, "You are being ratelimited.")

View File

@ -75,6 +75,13 @@ class RatelimitManager:
self._ratelimiters[path] = rtl
def get_ratelimit(self, key: str) -> Ratelimit:
def get_ratelimit(self, key: str, exact=False) -> Ratelimit:
"""Get the :class:`Ratelimit` instance for a given path."""
return self._ratelimiters.get(key, self.global_bucket)
bucket = self._ratelimiters.get(key)
if bucket:
return bucket
if not exact:
return self.global_bucket
else:
raise AssertionError(f"unknown ratelimit bucket '{key}'")

2
run.py
View File

@ -415,7 +415,7 @@ async def app_after_serving():
# first close all clients, then close db
tasks = app.state_manager.gen_close_tasks()
if tasks:
await asyncio.wait(tasks, loop=app.loop)
await asyncio.gather(*tasks)
app.state_manager.close()

View File

@ -39,6 +39,7 @@ from wsproto.events import (
from litecord.gateway.opcodes import OP
from litecord.gateway.websocket import decode_etf
from litecord.ratelimits.bucket import Ratelimit
# Z_SYNC_FLUSH suffix
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
@ -487,3 +488,46 @@ async def test_ready_bot_zlib_stream(test_cli_bot):
await extract_and_verify_ready(conn, zlib_stream=True)
finally:
await _close(conn)
@pytest.mark.asyncio
async def test_max_concurrency(test_cli_bot):
session_ratelimiter = test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"]
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = Ratelimit(1, 5)
try:
gateway_start_coroutines = [gw_start(test_cli_bot.cli) for _ in range(5)]
connections = await asyncio.gather(*gateway_start_coroutines)
# read all HELLOs we send by default
await asyncio.gather(*[_json(conn) for conn in connections])
# make everyone IDENTIFY
await asyncio.gather(
*[
_json_send(
conn,
{"op": OP.IDENTIFY, "d": {"token": test_cli_bot.user["token"]}},
)
for conn in connections
]
)
await asyncio.sleep(0.5)
# only one of them gets a READY
success_count, error_count = 0, 0
for conn in connections:
try:
data = await _json(conn)
if data["op"] == OP.DISPATCH:
success_count += 1
else:
error_count += 1
finally:
await _close(conn)
assert success_count == 1
assert error_count == 4
finally:
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = session_ratelimiter