mirror of https://gitlab.com/litecord/litecord.git
add test for max_concurrency
see https://gitlab.com/litecord/litecord/-/merge_requests/82#note_728950904
This commit is contained in:
parent
992b2cbf1f
commit
6024eee19b
|
|
@ -22,7 +22,7 @@ import asyncio
|
||||||
import pprint
|
import pprint
|
||||||
import zlib
|
import zlib
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, Any, Iterable, Optional
|
from typing import List, Dict, Any, Iterable, Optional, Union
|
||||||
from random import randint
|
from random import randint
|
||||||
|
|
||||||
import websockets
|
import websockets
|
||||||
|
|
@ -353,8 +353,8 @@ class GatewayWebsocket:
|
||||||
"""Send a packet but just the OP code information is filled in."""
|
"""Send a packet but just the OP code information is filled in."""
|
||||||
await self.send({"op": op_code, "d": data, "t": None, "s": None})
|
await self.send({"op": op_code, "d": data, "t": None, "s": None})
|
||||||
|
|
||||||
def _check_ratelimit(self, key: str, ratelimit_key):
|
def _check_ratelimit(self, key: str, ratelimit_key: Any):
|
||||||
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}")
|
ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}", exact=True)
|
||||||
bucket = ratelimit.get_bucket(ratelimit_key)
|
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||||
return bucket.update_rate_limit()
|
return bucket.update_rate_limit()
|
||||||
|
|
||||||
|
|
@ -758,7 +758,8 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
await self.send_op(OP.HEARTBEAT_ACK, None)
|
await self.send_op(OP.HEARTBEAT_ACK, None)
|
||||||
|
|
||||||
async def _connect_ratelimit(self, user_id: int):
|
async def _connect_ratelimit(self, user_id: Union[int, str]):
|
||||||
|
log.debug("validating connect ratelimit against {!r}", user_id)
|
||||||
if self._check_ratelimit("connect", user_id):
|
if self._check_ratelimit("connect", user_id):
|
||||||
await self.invalidate_session(False)
|
await self.invalidate_session(False)
|
||||||
raise WebsocketClose(4009, "You are being ratelimited.")
|
raise WebsocketClose(4009, "You are being ratelimited.")
|
||||||
|
|
|
||||||
|
|
@ -75,6 +75,13 @@ class RatelimitManager:
|
||||||
|
|
||||||
self._ratelimiters[path] = rtl
|
self._ratelimiters[path] = rtl
|
||||||
|
|
||||||
def get_ratelimit(self, key: str) -> Ratelimit:
|
def get_ratelimit(self, key: str, exact=False) -> Ratelimit:
|
||||||
"""Get the :class:`Ratelimit` instance for a given path."""
|
"""Get the :class:`Ratelimit` instance for a given path."""
|
||||||
return self._ratelimiters.get(key, self.global_bucket)
|
bucket = self._ratelimiters.get(key)
|
||||||
|
if bucket:
|
||||||
|
return bucket
|
||||||
|
|
||||||
|
if not exact:
|
||||||
|
return self.global_bucket
|
||||||
|
else:
|
||||||
|
raise AssertionError(f"unknown ratelimit bucket '{key}'")
|
||||||
|
|
|
||||||
2
run.py
2
run.py
|
|
@ -415,7 +415,7 @@ async def app_after_serving():
|
||||||
# first close all clients, then close db
|
# first close all clients, then close db
|
||||||
tasks = app.state_manager.gen_close_tasks()
|
tasks = app.state_manager.gen_close_tasks()
|
||||||
if tasks:
|
if tasks:
|
||||||
await asyncio.wait(tasks, loop=app.loop)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
app.state_manager.close()
|
app.state_manager.close()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -39,6 +39,7 @@ from wsproto.events import (
|
||||||
|
|
||||||
from litecord.gateway.opcodes import OP
|
from litecord.gateway.opcodes import OP
|
||||||
from litecord.gateway.websocket import decode_etf
|
from litecord.gateway.websocket import decode_etf
|
||||||
|
from litecord.ratelimits.bucket import Ratelimit
|
||||||
|
|
||||||
# Z_SYNC_FLUSH suffix
|
# Z_SYNC_FLUSH suffix
|
||||||
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
|
ZLIB_SUFFIX = b"\x00\x00\xff\xff"
|
||||||
|
|
@ -487,3 +488,46 @@ async def test_ready_bot_zlib_stream(test_cli_bot):
|
||||||
await extract_and_verify_ready(conn, zlib_stream=True)
|
await extract_and_verify_ready(conn, zlib_stream=True)
|
||||||
finally:
|
finally:
|
||||||
await _close(conn)
|
await _close(conn)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_concurrency(test_cli_bot):
|
||||||
|
session_ratelimiter = test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"]
|
||||||
|
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = Ratelimit(1, 5)
|
||||||
|
|
||||||
|
try:
|
||||||
|
|
||||||
|
gateway_start_coroutines = [gw_start(test_cli_bot.cli) for _ in range(5)]
|
||||||
|
connections = await asyncio.gather(*gateway_start_coroutines)
|
||||||
|
|
||||||
|
# read all HELLOs we send by default
|
||||||
|
await asyncio.gather(*[_json(conn) for conn in connections])
|
||||||
|
|
||||||
|
# make everyone IDENTIFY
|
||||||
|
await asyncio.gather(
|
||||||
|
*[
|
||||||
|
_json_send(
|
||||||
|
conn,
|
||||||
|
{"op": OP.IDENTIFY, "d": {"token": test_cli_bot.user["token"]}},
|
||||||
|
)
|
||||||
|
for conn in connections
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# only one of them gets a READY
|
||||||
|
success_count, error_count = 0, 0
|
||||||
|
for conn in connections:
|
||||||
|
try:
|
||||||
|
data = await _json(conn)
|
||||||
|
if data["op"] == OP.DISPATCH:
|
||||||
|
success_count += 1
|
||||||
|
else:
|
||||||
|
error_count += 1
|
||||||
|
finally:
|
||||||
|
await _close(conn)
|
||||||
|
assert success_count == 1
|
||||||
|
assert error_count == 4
|
||||||
|
|
||||||
|
finally:
|
||||||
|
test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = session_ratelimiter
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue