From 33f893c0ff5c92a5a037d45bd517016c65f127e5 Mon Sep 17 00:00:00 2001 From: Luna Mendes Date: Sun, 4 Nov 2018 02:06:40 -0300 Subject: [PATCH] all: add ratelimit implementation haven't tested yet, but it should work in theory. - gateway.websocket: add the 3 main ws ratelimits - litecord: add ratelimits package - ratelimits.main: add implementation - run: add app_set_ratelimit_headers --- litecord/errors.py | 4 ++ litecord/gateway/websocket.py | 30 ++++++++- litecord/ratelimits/bucket.py | 113 +++++++++++++++++++++++++++++++++ litecord/ratelimits/handler.py | 67 +++++++++++++++++++ litecord/ratelimits/main.py | 56 ++++++++++++++-- run.py | 26 +++++++- 6 files changed, 288 insertions(+), 8 deletions(-) create mode 100644 litecord/ratelimits/bucket.py create mode 100644 litecord/ratelimits/handler.py diff --git a/litecord/errors.py b/litecord/errors.py index fe4f130..e1fe1ac 100644 --- a/litecord/errors.py +++ b/litecord/errors.py @@ -41,6 +41,10 @@ class MessageNotFound(LitecordError): status_code = 404 +class Ratelimited(LitecordError): + status_code = 429 + + class WebsocketClose(Exception): @property def code(self): diff --git a/litecord/gateway/websocket.py b/litecord/gateway/websocket.py index 0d54ad3..e03e12c 100644 --- a/litecord/gateway/websocket.py +++ b/litecord/gateway/websocket.py @@ -28,7 +28,8 @@ WebsocketProperties = collections.namedtuple( ) WebsocketObjects = collections.namedtuple( - 'WebsocketObjects', 'db state_manager storage loop dispatcher presence' + 'WebsocketObjects', ('db', 'state_manager', 'storage', + 'loop', 'dispatcher', 'presence', 'ratelimiter') ) @@ -138,6 +139,11 @@ class GatewayWebsocket: else: await self.ws.send(encoded.decode()) + def _check_ratelimit(self, key: str, ratelimit_key: str): + ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}') + bucket = ratelimit.get_bucket(ratelimit_key) + return bucket.update_rate_limit() + async def _hb_wait(self, interval: int): """Wait heartbeat""" # if the client heartbeats in time, @@ -342,6 +348,14 @@ class GatewayWebsocket: async def update_status(self, status: dict): """Update the status of the current websocket connection.""" + if not self.state: + return + + if self._check_ratelimit('presence', self.state.session_id): + # Presence Updates beyond the ratelimit + # are just silently dropped. + return + if status is None: status = { 'afk': False, @@ -395,6 +409,11 @@ class GatewayWebsocket: 'op': OP.HEARTBEAT_ACK, }) + async def _connect_ratelimit(self, user_id: int): + if self._check_ratelimit('connect', user_id): + await self.invalidate_session(False) + raise WebsocketClose(4009, 'You are being ratelimited.') + async def handle_2(self, payload: Dict[str, Any]): """Handle the OP 2 Identify packet.""" try: @@ -414,6 +433,8 @@ class GatewayWebsocket: except (Unauthorized, Forbidden): raise WebsocketClose(4004, 'Authentication failed') + await self._connect_ratelimit(user_id) + bot = await self.ext.db.fetchval(""" SELECT bot FROM users WHERE id = $1 @@ -751,6 +772,10 @@ class GatewayWebsocket: await handler(payload) + async def _msg_ratelimit(self): + if self._check_ratelimit('messages', self.state.session_id): + raise WebsocketClose(4008, 'You are being ratelimited.') + async def listen_messages(self): """Listen for messages coming in from the websocket.""" @@ -767,6 +792,9 @@ class GatewayWebsocket: if len(message) > 4096: raise DecodeError('Payload length exceeded') + if self.state: + await self._msg_ratelimit() + payload = self.decoder(message) await self.process_message(payload) diff --git a/litecord/ratelimits/bucket.py b/litecord/ratelimits/bucket.py new file mode 100644 index 0000000..dabb0ae --- /dev/null +++ b/litecord/ratelimits/bucket.py @@ -0,0 +1,113 @@ +""" +main litecord ratelimiting code + + This code was copied from elixire's ratelimiting, + which in turn is a work on top of discord.py's ratelimiting. +""" +import time + + +class RatelimitBucket: + """Main ratelimit bucket class.""" + def __init__(self, tokens, second): + self.requests = tokens + self.second = second + + self._window = 0.0 + self._tokens = self.requests + self.retries = 0 + self._last = 0.0 + + def get_tokens(self, current): + """Get the current amount of available tokens.""" + if not current: + current = time.time() + + # by default, use _tokens + tokens = self._tokens + + # if current timestamp is above _window + seconds + # reset tokens to self.requests (default) + if current > self._window + self.second: + tokens = self.requests + + return tokens + + def update_rate_limit(self): + """Update current ratelimit state.""" + current = time.time() + self._last = current + self._tokens = self.get_tokens(current) + + # we are using the ratelimit for the first time + # so set current ratelimit window to right now + if self._tokens == self.requests: + self._window = current + + # Are we currently ratelimited? + if self._tokens == 0: + self.retries += 1 + return self.second - (current - self._window) + + # if not ratelimited, remove a token + self.retries = 0 + self._tokens -= 1 + + # if we got ratelimited after that token removal, + # set window to now + if self._tokens == 0: + self._window = current + + def reset(self): + """Reset current ratelimit to default state.""" + self._tokens = self.requests + self._last = 0.0 + self.retries = 0 + + def copy(self): + """Create a copy of this ratelimit. + + Used to manage multiple ratelimits to users. + """ + return RatelimitBucket(self.requests, + self.second) + + def __repr__(self): + return (f'') + + +class Ratelimit: + """Manages buckets.""" + def __init__(self, tokens, second, keys=None): + self._cache = {} + if keys is None: + keys = tuple() + self.keys = keys + self._cooldown = RatelimitBucket(tokens, second) + + def __repr__(self): + return (f'') + + def _verify_cache(self): + current = time.time() + dead_keys = [k for k, v in self._cache.items() + if current > v._last + v.second] + + for k in dead_keys: + del self._cache[k] + + def get_bucket(self, key) -> RatelimitBucket: + if not self._cooldown: + return None + + self._verify_cache() + + if key not in self._cache: + bucket = self._cooldown.copy() + self._cache[key] = bucket + else: + bucket = self._cache[key] + + return bucket diff --git a/litecord/ratelimits/handler.py b/litecord/ratelimits/handler.py new file mode 100644 index 0000000..db896bf --- /dev/null +++ b/litecord/ratelimits/handler.py @@ -0,0 +1,67 @@ +from quart import current_app as app, request, g + +from litecord.errors import Ratelimited +from litecord.auth import token_check, Unauthorized + + +async def _check_bucket(bucket): + retry_after = bucket.update_rate_limit() + + request.bucket = bucket + + if retry_after: + raise Ratelimited('You are being ratelimited.', { + 'retry_after': retry_after + }) + + +async def _handle_global(ratelimit): + """Global ratelimit is per-user.""" + try: + user_id = await token_check() + except Unauthorized: + user_id = request.remote_addr + + bucket = ratelimit.get_bucket(user_id) + await _check_bucket(bucket) + + +async def _handle_specific(ratelimit): + try: + user_id = await token_check() + except Unauthorized: + user_id = request.remote_addr + + # construct the key based on the ratelimit.keys + keys = ratelimit.keys + + # base key is the user id + key_components = [f'user_id:{user_id}'] + + for key in keys: + val = request.view_args[key] + key_components.append(f'{key}:{val}') + + bucket_key = ':'.join(key_components) + bucket = ratelimit.get_bucket(bucket_key) + await _check_bucket(bucket) + + +async def ratelimit_handler(): + """Main ratelimit handler. + + Decides on which ratelimit to use. + """ + rule = request.url_rule + + # rule.endpoint is composed of '.' + # and so we can use that to make routes with different + # methods have different ratelimits + rule_path = rule.endpoint + + try: + ratelimit = app.ratelimiter.get_ratelimit(rule_path) + await _handle_specific(ratelimit) + except KeyError: + ratelimit = app.ratelimiter.global_bucket + await _handle_global(ratelimit) diff --git a/litecord/ratelimits/main.py b/litecord/ratelimits/main.py index 4ab71d6..ffcc54d 100644 --- a/litecord/ratelimits/main.py +++ b/litecord/ratelimits/main.py @@ -1,5 +1,53 @@ -from quart import current_app as app, request +from litecord.ratelimits.bucket import Ratelimit -async def ratelimit_handler(): - # dummy handler for future code - print(request.headers) +""" +REST: + POST Message | 5/5s | per-channel + DELETE Message | 5/1s | per-channel + PUT/DELETE Reaction | 1/0.25s | per-channel + PATCH Member | 10/10s | per-guild + PATCH Member Nick | 1/1s | per-guild + PATCH Username | 2/3600s | per-account + |All Requests| | 50/1s | per-account +WS: + Gateway Connect | 1/5s | per-account + Presence Update | 5/60s | per-session + |All Sent Messages| | 120/60s | per-session +""" + +REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id')) + +RATELIMITS = { + 'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')), + 'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')), + + # all of those share the same bucket. + 'channel_reactions.add_reaction': REACTION_BUCKET, + 'channel_reactions.remove_own_reaction': REACTION_BUCKET, + 'channel_reactions.remove_user_reaction': REACTION_BUCKET, + + 'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')), + 'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')), + + # this only applies to username. + # 'users.patch_me': Ratelimit(2, 3600), + + '_ws.connect': Ratelimit(1, 5), + '_ws.presence': Ratelimit(5, 60), + '_ws.messages': Ratelimit(120, 60), +} + +class RatelimitManager: + """Manager for the bucket managers""" + def __init__(self): + self._ratelimiters = {} + self.global_bucket = Ratelimit(50, 1) + self._fill_rtl() + + def _fill_rtl(self): + for path, rtl in RATELIMITS.items(): + self._ratelimiters[path] = rtl + + def get_ratelimit(self, key: str) -> Ratelimit: + """Get the :class:`Ratelimit` instance for a given path.""" + return self._ratelimiters.get(key, self.global_bucket) diff --git a/run.py b/run.py index 195aa0c..ff3e623 100644 --- a/run.py +++ b/run.py @@ -9,14 +9,14 @@ from quart import Quart, g, jsonify, request from logbook import StreamHandler, Logger from logbook.compat import redirect_logging +# import the config set by instance owner import config + from litecord.blueprints import ( gateway, auth, users, guilds, channels, webhooks, science, voice, invites, relationships, dms ) -from litecord.ratelimits.main import ratelimit_handler - # those blueprints are separated from the "main" ones # for code readability if people want to dig through # the codebase. @@ -28,6 +28,9 @@ from litecord.blueprints.channel import ( channel_messages, channel_reactions ) +from litecord.ratelimits.handler import ratelimit_handler +from litecord.ratelimits.main import RatelimitManager + from litecord.gateway import websocket_handler from litecord.errors import LitecordError from litecord.gateway.state_manager import StateManager @@ -110,6 +113,21 @@ async def app_after_request(resp): # resp.headers['Access-Control-Allow-Methods'] = '*' resp.headers['Access-Control-Allow-Methods'] = \ resp.headers.get('allow', '*') + + return resp + + +@app.after_request +async def app_set_ratelimit_headers(resp): + """Set the specific ratelimit headers.""" + try: + bucket = request.bucket + resp.headers['X-RateLimit-Limit'] = str(bucket.requests) + resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens) + resp.headers['X-RateLimit-Reset'] = str(bucket._window + bucket.second) + except AttributeError: + pass + return resp @@ -123,6 +141,7 @@ async def app_before_serving(): app.loop = asyncio.get_event_loop() g.loop = asyncio.get_event_loop() + app.ratelimiter = RatelimitManager() app.state_manager = StateManager() app.storage = Storage(app.db) @@ -141,7 +160,8 @@ async def app_before_serving(): # TODO: pass just the app object await websocket_handler((app.db, app.state_manager, app.storage, - app.loop, app.dispatcher, app.presence), + app.loop, app.dispatcher, app.presence, + app.ratelimiter), ws, url) ws_future = websockets.serve(_wrapper, host, port)