From 6024eee19b13672c7c1f6a18a8ac7fc6f02f8b76 Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 13 Aug 2022 19:35:51 -0300 Subject: [PATCH] add test for max_concurrency see https://gitlab.com/litecord/litecord/-/merge_requests/82#note_728950904 --- litecord/gateway/websocket.py | 9 +++---- litecord/ratelimits/main.py | 11 +++++++-- run.py | 2 +- tests/test_websocket.py | 44 +++++++++++++++++++++++++++++++++++ 4 files changed, 59 insertions(+), 7 deletions(-) diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 3f12d06..c755185 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -22,7 +22,7 @@ import asyncio import pprint import zlib import time -from typing import List, Dict, Any, Iterable, Optional +from typing import List, Dict, Any, Iterable, Optional, Union from random import randint import websockets @@ -353,8 +353,8 @@ class GatewayWebsocket: """Send a packet but just the OP code information is filled in.""" await self.send({"op": op_code, "d": data, "t": None, "s": None}) - def _check_ratelimit(self, key: str, ratelimit_key): - ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}") + def _check_ratelimit(self, key: str, ratelimit_key: Any): + ratelimit = self.app.ratelimiter.get_ratelimit(f"_ws.{key}", exact=True) bucket = ratelimit.get_bucket(ratelimit_key) return bucket.update_rate_limit() @@ -758,7 +758,8 @@ class GatewayWebsocket: await self.send_op(OP.HEARTBEAT_ACK, None) - async def _connect_ratelimit(self, user_id: int): + async def _connect_ratelimit(self, user_id: Union[int, str]): + log.debug("validating connect ratelimit against {!r}", user_id) if self._check_ratelimit("connect", user_id): await self.invalidate_session(False) raise WebsocketClose(4009, "You are being ratelimited.") diff --git a/litecord/ratelimits/main.py b/litecord/ratelimits/main.py index 6ede010..0c3ff95 100644 --- a/litecord/ratelimits/main.py +++ b/litecord/ratelimits/main.py @@ -75,6 +75,13 @@ class RatelimitManager: self._ratelimiters[path] = rtl - def get_ratelimit(self, key: str) -> Ratelimit: + def get_ratelimit(self, key: str, exact=False) -> Ratelimit: """Get the :class:`Ratelimit` instance for a given path.""" - return self._ratelimiters.get(key, self.global_bucket) + bucket = self._ratelimiters.get(key) + if bucket: + return bucket + + if not exact: + return self.global_bucket + else: + raise AssertionError(f"unknown ratelimit bucket '{key}'") diff --git a/run.py b/run.py index a966b48..0c80c06 100644 --- a/run.py +++ b/run.py @@ -415,7 +415,7 @@ async def app_after_serving(): # first close all clients, then close db tasks = app.state_manager.gen_close_tasks() if tasks: - await asyncio.wait(tasks, loop=app.loop) + await asyncio.gather(*tasks) app.state_manager.close() diff --git a/tests/test_websocket.py b/tests/test_websocket.py index ee8d3fe..e8f7f00 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -39,6 +39,7 @@ from wsproto.events import ( from litecord.gateway.opcodes import OP from litecord.gateway.websocket import decode_etf +from litecord.ratelimits.bucket import Ratelimit # Z_SYNC_FLUSH suffix ZLIB_SUFFIX = b"\x00\x00\xff\xff" @@ -487,3 +488,46 @@ async def test_ready_bot_zlib_stream(test_cli_bot): await extract_and_verify_ready(conn, zlib_stream=True) finally: await _close(conn) + + +@pytest.mark.asyncio +async def test_max_concurrency(test_cli_bot): + session_ratelimiter = test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] + test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = Ratelimit(1, 5) + + try: + + gateway_start_coroutines = [gw_start(test_cli_bot.cli) for _ in range(5)] + connections = await asyncio.gather(*gateway_start_coroutines) + + # read all HELLOs we send by default + await asyncio.gather(*[_json(conn) for conn in connections]) + + # make everyone IDENTIFY + await asyncio.gather( + *[ + _json_send( + conn, + {"op": OP.IDENTIFY, "d": {"token": test_cli_bot.user["token"]}}, + ) + for conn in connections + ] + ) + await asyncio.sleep(0.5) + + # only one of them gets a READY + success_count, error_count = 0, 0 + for conn in connections: + try: + data = await _json(conn) + if data["op"] == OP.DISPATCH: + success_count += 1 + else: + error_count += 1 + finally: + await _close(conn) + assert success_count == 1 + assert error_count == 4 + + finally: + test_cli_bot.app.ratelimiter._ratelimiters["_ws.connect"] = session_ratelimiter