mirror of https://gitlab.com/litecord/litecord.git
all: add ratelimit implementation
haven't tested yet, but it should work in theory. - gateway.websocket: add the 3 main ws ratelimits - litecord: add ratelimits package - ratelimits.main: add implementation - run: add app_set_ratelimit_headers
This commit is contained in:
parent
b17cfd46eb
commit
33f893c0ff
|
|
@ -41,6 +41,10 @@ class MessageNotFound(LitecordError):
|
|||
status_code = 404
|
||||
|
||||
|
||||
class Ratelimited(LitecordError):
|
||||
status_code = 429
|
||||
|
||||
|
||||
class WebsocketClose(Exception):
|
||||
@property
|
||||
def code(self):
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ WebsocketProperties = collections.namedtuple(
|
|||
)
|
||||
|
||||
WebsocketObjects = collections.namedtuple(
|
||||
'WebsocketObjects', 'db state_manager storage loop dispatcher presence'
|
||||
'WebsocketObjects', ('db', 'state_manager', 'storage',
|
||||
'loop', 'dispatcher', 'presence', 'ratelimiter')
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -138,6 +139,11 @@ class GatewayWebsocket:
|
|||
else:
|
||||
await self.ws.send(encoded.decode())
|
||||
|
||||
def _check_ratelimit(self, key: str, ratelimit_key: str):
|
||||
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
|
||||
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||
return bucket.update_rate_limit()
|
||||
|
||||
async def _hb_wait(self, interval: int):
|
||||
"""Wait heartbeat"""
|
||||
# if the client heartbeats in time,
|
||||
|
|
@ -342,6 +348,14 @@ class GatewayWebsocket:
|
|||
|
||||
async def update_status(self, status: dict):
|
||||
"""Update the status of the current websocket connection."""
|
||||
if not self.state:
|
||||
return
|
||||
|
||||
if self._check_ratelimit('presence', self.state.session_id):
|
||||
# Presence Updates beyond the ratelimit
|
||||
# are just silently dropped.
|
||||
return
|
||||
|
||||
if status is None:
|
||||
status = {
|
||||
'afk': False,
|
||||
|
|
@ -395,6 +409,11 @@ class GatewayWebsocket:
|
|||
'op': OP.HEARTBEAT_ACK,
|
||||
})
|
||||
|
||||
async def _connect_ratelimit(self, user_id: int):
|
||||
if self._check_ratelimit('connect', user_id):
|
||||
await self.invalidate_session(False)
|
||||
raise WebsocketClose(4009, 'You are being ratelimited.')
|
||||
|
||||
async def handle_2(self, payload: Dict[str, Any]):
|
||||
"""Handle the OP 2 Identify packet."""
|
||||
try:
|
||||
|
|
@ -414,6 +433,8 @@ class GatewayWebsocket:
|
|||
except (Unauthorized, Forbidden):
|
||||
raise WebsocketClose(4004, 'Authentication failed')
|
||||
|
||||
await self._connect_ratelimit(user_id)
|
||||
|
||||
bot = await self.ext.db.fetchval("""
|
||||
SELECT bot FROM users
|
||||
WHERE id = $1
|
||||
|
|
@ -751,6 +772,10 @@ class GatewayWebsocket:
|
|||
|
||||
await handler(payload)
|
||||
|
||||
async def _msg_ratelimit(self):
|
||||
if self._check_ratelimit('messages', self.state.session_id):
|
||||
raise WebsocketClose(4008, 'You are being ratelimited.')
|
||||
|
||||
async def listen_messages(self):
|
||||
"""Listen for messages coming in from the websocket."""
|
||||
|
||||
|
|
@ -767,6 +792,9 @@ class GatewayWebsocket:
|
|||
if len(message) > 4096:
|
||||
raise DecodeError('Payload length exceeded')
|
||||
|
||||
if self.state:
|
||||
await self._msg_ratelimit()
|
||||
|
||||
payload = self.decoder(message)
|
||||
await self.process_message(payload)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,113 @@
|
|||
"""
|
||||
main litecord ratelimiting code
|
||||
|
||||
This code was copied from elixire's ratelimiting,
|
||||
which in turn is a work on top of discord.py's ratelimiting.
|
||||
"""
|
||||
import time
|
||||
|
||||
|
||||
class RatelimitBucket:
|
||||
"""Main ratelimit bucket class."""
|
||||
def __init__(self, tokens, second):
|
||||
self.requests = tokens
|
||||
self.second = second
|
||||
|
||||
self._window = 0.0
|
||||
self._tokens = self.requests
|
||||
self.retries = 0
|
||||
self._last = 0.0
|
||||
|
||||
def get_tokens(self, current):
|
||||
"""Get the current amount of available tokens."""
|
||||
if not current:
|
||||
current = time.time()
|
||||
|
||||
# by default, use _tokens
|
||||
tokens = self._tokens
|
||||
|
||||
# if current timestamp is above _window + seconds
|
||||
# reset tokens to self.requests (default)
|
||||
if current > self._window + self.second:
|
||||
tokens = self.requests
|
||||
|
||||
return tokens
|
||||
|
||||
def update_rate_limit(self):
|
||||
"""Update current ratelimit state."""
|
||||
current = time.time()
|
||||
self._last = current
|
||||
self._tokens = self.get_tokens(current)
|
||||
|
||||
# we are using the ratelimit for the first time
|
||||
# so set current ratelimit window to right now
|
||||
if self._tokens == self.requests:
|
||||
self._window = current
|
||||
|
||||
# Are we currently ratelimited?
|
||||
if self._tokens == 0:
|
||||
self.retries += 1
|
||||
return self.second - (current - self._window)
|
||||
|
||||
# if not ratelimited, remove a token
|
||||
self.retries = 0
|
||||
self._tokens -= 1
|
||||
|
||||
# if we got ratelimited after that token removal,
|
||||
# set window to now
|
||||
if self._tokens == 0:
|
||||
self._window = current
|
||||
|
||||
def reset(self):
|
||||
"""Reset current ratelimit to default state."""
|
||||
self._tokens = self.requests
|
||||
self._last = 0.0
|
||||
self.retries = 0
|
||||
|
||||
def copy(self):
|
||||
"""Create a copy of this ratelimit.
|
||||
|
||||
Used to manage multiple ratelimits to users.
|
||||
"""
|
||||
return RatelimitBucket(self.requests,
|
||||
self.second)
|
||||
|
||||
def __repr__(self):
|
||||
return (f'<RatelimitBucket requests={self.requests} '
|
||||
f'second={self.second} window: {self._window} '
|
||||
f'tokens={self._tokens}>')
|
||||
|
||||
|
||||
class Ratelimit:
|
||||
"""Manages buckets."""
|
||||
def __init__(self, tokens, second, keys=None):
|
||||
self._cache = {}
|
||||
if keys is None:
|
||||
keys = tuple()
|
||||
self.keys = keys
|
||||
self._cooldown = RatelimitBucket(tokens, second)
|
||||
|
||||
def __repr__(self):
|
||||
return (f'<Ratelimit cooldown={self._cooldown}>')
|
||||
|
||||
def _verify_cache(self):
|
||||
current = time.time()
|
||||
dead_keys = [k for k, v in self._cache.items()
|
||||
if current > v._last + v.second]
|
||||
|
||||
for k in dead_keys:
|
||||
del self._cache[k]
|
||||
|
||||
def get_bucket(self, key) -> RatelimitBucket:
|
||||
if not self._cooldown:
|
||||
return None
|
||||
|
||||
self._verify_cache()
|
||||
|
||||
if key not in self._cache:
|
||||
bucket = self._cooldown.copy()
|
||||
self._cache[key] = bucket
|
||||
else:
|
||||
bucket = self._cache[key]
|
||||
|
||||
return bucket
|
||||
|
|
@ -0,0 +1,67 @@
|
|||
from quart import current_app as app, request, g
|
||||
|
||||
from litecord.errors import Ratelimited
|
||||
from litecord.auth import token_check, Unauthorized
|
||||
|
||||
|
||||
async def _check_bucket(bucket):
|
||||
retry_after = bucket.update_rate_limit()
|
||||
|
||||
request.bucket = bucket
|
||||
|
||||
if retry_after:
|
||||
raise Ratelimited('You are being ratelimited.', {
|
||||
'retry_after': retry_after
|
||||
})
|
||||
|
||||
|
||||
async def _handle_global(ratelimit):
|
||||
"""Global ratelimit is per-user."""
|
||||
try:
|
||||
user_id = await token_check()
|
||||
except Unauthorized:
|
||||
user_id = request.remote_addr
|
||||
|
||||
bucket = ratelimit.get_bucket(user_id)
|
||||
await _check_bucket(bucket)
|
||||
|
||||
|
||||
async def _handle_specific(ratelimit):
|
||||
try:
|
||||
user_id = await token_check()
|
||||
except Unauthorized:
|
||||
user_id = request.remote_addr
|
||||
|
||||
# construct the key based on the ratelimit.keys
|
||||
keys = ratelimit.keys
|
||||
|
||||
# base key is the user id
|
||||
key_components = [f'user_id:{user_id}']
|
||||
|
||||
for key in keys:
|
||||
val = request.view_args[key]
|
||||
key_components.append(f'{key}:{val}')
|
||||
|
||||
bucket_key = ':'.join(key_components)
|
||||
bucket = ratelimit.get_bucket(bucket_key)
|
||||
await _check_bucket(bucket)
|
||||
|
||||
|
||||
async def ratelimit_handler():
|
||||
"""Main ratelimit handler.
|
||||
|
||||
Decides on which ratelimit to use.
|
||||
"""
|
||||
rule = request.url_rule
|
||||
|
||||
# rule.endpoint is composed of '<blueprint>.<function>'
|
||||
# and so we can use that to make routes with different
|
||||
# methods have different ratelimits
|
||||
rule_path = rule.endpoint
|
||||
|
||||
try:
|
||||
ratelimit = app.ratelimiter.get_ratelimit(rule_path)
|
||||
await _handle_specific(ratelimit)
|
||||
except KeyError:
|
||||
ratelimit = app.ratelimiter.global_bucket
|
||||
await _handle_global(ratelimit)
|
||||
|
|
@ -1,5 +1,53 @@
|
|||
from quart import current_app as app, request
|
||||
from litecord.ratelimits.bucket import Ratelimit
|
||||
|
||||
async def ratelimit_handler():
|
||||
# dummy handler for future code
|
||||
print(request.headers)
|
||||
"""
|
||||
REST:
|
||||
POST Message | 5/5s | per-channel
|
||||
DELETE Message | 5/1s | per-channel
|
||||
PUT/DELETE Reaction | 1/0.25s | per-channel
|
||||
PATCH Member | 10/10s | per-guild
|
||||
PATCH Member Nick | 1/1s | per-guild
|
||||
PATCH Username | 2/3600s | per-account
|
||||
|All Requests| | 50/1s | per-account
|
||||
WS:
|
||||
Gateway Connect | 1/5s | per-account
|
||||
Presence Update | 5/60s | per-session
|
||||
|All Sent Messages| | 120/60s | per-session
|
||||
"""
|
||||
|
||||
REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id'))
|
||||
|
||||
RATELIMITS = {
|
||||
'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')),
|
||||
'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')),
|
||||
|
||||
# all of those share the same bucket.
|
||||
'channel_reactions.add_reaction': REACTION_BUCKET,
|
||||
'channel_reactions.remove_own_reaction': REACTION_BUCKET,
|
||||
'channel_reactions.remove_user_reaction': REACTION_BUCKET,
|
||||
|
||||
'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')),
|
||||
'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')),
|
||||
|
||||
# this only applies to username.
|
||||
# 'users.patch_me': Ratelimit(2, 3600),
|
||||
|
||||
'_ws.connect': Ratelimit(1, 5),
|
||||
'_ws.presence': Ratelimit(5, 60),
|
||||
'_ws.messages': Ratelimit(120, 60),
|
||||
}
|
||||
|
||||
class RatelimitManager:
|
||||
"""Manager for the bucket managers"""
|
||||
def __init__(self):
|
||||
self._ratelimiters = {}
|
||||
self.global_bucket = Ratelimit(50, 1)
|
||||
self._fill_rtl()
|
||||
|
||||
def _fill_rtl(self):
|
||||
for path, rtl in RATELIMITS.items():
|
||||
self._ratelimiters[path] = rtl
|
||||
|
||||
def get_ratelimit(self, key: str) -> Ratelimit:
|
||||
"""Get the :class:`Ratelimit` instance for a given path."""
|
||||
return self._ratelimiters.get(key, self.global_bucket)
|
||||
|
|
|
|||
26
run.py
26
run.py
|
|
@ -9,14 +9,14 @@ from quart import Quart, g, jsonify, request
|
|||
from logbook import StreamHandler, Logger
|
||||
from logbook.compat import redirect_logging
|
||||
|
||||
# import the config set by instance owner
|
||||
import config
|
||||
|
||||
from litecord.blueprints import (
|
||||
gateway, auth, users, guilds, channels, webhooks, science,
|
||||
voice, invites, relationships, dms
|
||||
)
|
||||
|
||||
from litecord.ratelimits.main import ratelimit_handler
|
||||
|
||||
# those blueprints are separated from the "main" ones
|
||||
# for code readability if people want to dig through
|
||||
# the codebase.
|
||||
|
|
@ -28,6 +28,9 @@ from litecord.blueprints.channel import (
|
|||
channel_messages, channel_reactions
|
||||
)
|
||||
|
||||
from litecord.ratelimits.handler import ratelimit_handler
|
||||
from litecord.ratelimits.main import RatelimitManager
|
||||
|
||||
from litecord.gateway import websocket_handler
|
||||
from litecord.errors import LitecordError
|
||||
from litecord.gateway.state_manager import StateManager
|
||||
|
|
@ -110,6 +113,21 @@ async def app_after_request(resp):
|
|||
# resp.headers['Access-Control-Allow-Methods'] = '*'
|
||||
resp.headers['Access-Control-Allow-Methods'] = \
|
||||
resp.headers.get('allow', '*')
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
@app.after_request
|
||||
async def app_set_ratelimit_headers(resp):
|
||||
"""Set the specific ratelimit headers."""
|
||||
try:
|
||||
bucket = request.bucket
|
||||
resp.headers['X-RateLimit-Limit'] = str(bucket.requests)
|
||||
resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens)
|
||||
resp.headers['X-RateLimit-Reset'] = str(bucket._window + bucket.second)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
|
|
@ -123,6 +141,7 @@ async def app_before_serving():
|
|||
app.loop = asyncio.get_event_loop()
|
||||
g.loop = asyncio.get_event_loop()
|
||||
|
||||
app.ratelimiter = RatelimitManager()
|
||||
app.state_manager = StateManager()
|
||||
app.storage = Storage(app.db)
|
||||
|
||||
|
|
@ -141,7 +160,8 @@ async def app_before_serving():
|
|||
|
||||
# TODO: pass just the app object
|
||||
await websocket_handler((app.db, app.state_manager, app.storage,
|
||||
app.loop, app.dispatcher, app.presence),
|
||||
app.loop, app.dispatcher, app.presence,
|
||||
app.ratelimiter),
|
||||
ws, url)
|
||||
|
||||
ws_future = websockets.serve(_wrapper, host, port)
|
||||
|
|
|
|||
Loading…
Reference in New Issue