mirror of https://gitlab.com/litecord/litecord.git
all: add ratelimit implementation
haven't tested yet, but it should work in theory. - gateway.websocket: add the 3 main ws ratelimits - litecord: add ratelimits package - ratelimits.main: add implementation - run: add app_set_ratelimit_headers
This commit is contained in:
parent
b17cfd46eb
commit
33f893c0ff
|
|
@ -41,6 +41,10 @@ class MessageNotFound(LitecordError):
|
||||||
status_code = 404
|
status_code = 404
|
||||||
|
|
||||||
|
|
||||||
|
class Ratelimited(LitecordError):
|
||||||
|
status_code = 429
|
||||||
|
|
||||||
|
|
||||||
class WebsocketClose(Exception):
|
class WebsocketClose(Exception):
|
||||||
@property
|
@property
|
||||||
def code(self):
|
def code(self):
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,8 @@ WebsocketProperties = collections.namedtuple(
|
||||||
)
|
)
|
||||||
|
|
||||||
WebsocketObjects = collections.namedtuple(
|
WebsocketObjects = collections.namedtuple(
|
||||||
'WebsocketObjects', 'db state_manager storage loop dispatcher presence'
|
'WebsocketObjects', ('db', 'state_manager', 'storage',
|
||||||
|
'loop', 'dispatcher', 'presence', 'ratelimiter')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -138,6 +139,11 @@ class GatewayWebsocket:
|
||||||
else:
|
else:
|
||||||
await self.ws.send(encoded.decode())
|
await self.ws.send(encoded.decode())
|
||||||
|
|
||||||
|
def _check_ratelimit(self, key: str, ratelimit_key: str):
|
||||||
|
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
|
||||||
|
bucket = ratelimit.get_bucket(ratelimit_key)
|
||||||
|
return bucket.update_rate_limit()
|
||||||
|
|
||||||
async def _hb_wait(self, interval: int):
|
async def _hb_wait(self, interval: int):
|
||||||
"""Wait heartbeat"""
|
"""Wait heartbeat"""
|
||||||
# if the client heartbeats in time,
|
# if the client heartbeats in time,
|
||||||
|
|
@ -342,6 +348,14 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
async def update_status(self, status: dict):
|
async def update_status(self, status: dict):
|
||||||
"""Update the status of the current websocket connection."""
|
"""Update the status of the current websocket connection."""
|
||||||
|
if not self.state:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._check_ratelimit('presence', self.state.session_id):
|
||||||
|
# Presence Updates beyond the ratelimit
|
||||||
|
# are just silently dropped.
|
||||||
|
return
|
||||||
|
|
||||||
if status is None:
|
if status is None:
|
||||||
status = {
|
status = {
|
||||||
'afk': False,
|
'afk': False,
|
||||||
|
|
@ -395,6 +409,11 @@ class GatewayWebsocket:
|
||||||
'op': OP.HEARTBEAT_ACK,
|
'op': OP.HEARTBEAT_ACK,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
async def _connect_ratelimit(self, user_id: int):
|
||||||
|
if self._check_ratelimit('connect', user_id):
|
||||||
|
await self.invalidate_session(False)
|
||||||
|
raise WebsocketClose(4009, 'You are being ratelimited.')
|
||||||
|
|
||||||
async def handle_2(self, payload: Dict[str, Any]):
|
async def handle_2(self, payload: Dict[str, Any]):
|
||||||
"""Handle the OP 2 Identify packet."""
|
"""Handle the OP 2 Identify packet."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -414,6 +433,8 @@ class GatewayWebsocket:
|
||||||
except (Unauthorized, Forbidden):
|
except (Unauthorized, Forbidden):
|
||||||
raise WebsocketClose(4004, 'Authentication failed')
|
raise WebsocketClose(4004, 'Authentication failed')
|
||||||
|
|
||||||
|
await self._connect_ratelimit(user_id)
|
||||||
|
|
||||||
bot = await self.ext.db.fetchval("""
|
bot = await self.ext.db.fetchval("""
|
||||||
SELECT bot FROM users
|
SELECT bot FROM users
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
|
|
@ -751,6 +772,10 @@ class GatewayWebsocket:
|
||||||
|
|
||||||
await handler(payload)
|
await handler(payload)
|
||||||
|
|
||||||
|
async def _msg_ratelimit(self):
|
||||||
|
if self._check_ratelimit('messages', self.state.session_id):
|
||||||
|
raise WebsocketClose(4008, 'You are being ratelimited.')
|
||||||
|
|
||||||
async def listen_messages(self):
|
async def listen_messages(self):
|
||||||
"""Listen for messages coming in from the websocket."""
|
"""Listen for messages coming in from the websocket."""
|
||||||
|
|
||||||
|
|
@ -767,6 +792,9 @@ class GatewayWebsocket:
|
||||||
if len(message) > 4096:
|
if len(message) > 4096:
|
||||||
raise DecodeError('Payload length exceeded')
|
raise DecodeError('Payload length exceeded')
|
||||||
|
|
||||||
|
if self.state:
|
||||||
|
await self._msg_ratelimit()
|
||||||
|
|
||||||
payload = self.decoder(message)
|
payload = self.decoder(message)
|
||||||
await self.process_message(payload)
|
await self.process_message(payload)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
"""
|
||||||
|
main litecord ratelimiting code
|
||||||
|
|
||||||
|
This code was copied from elixire's ratelimiting,
|
||||||
|
which in turn is a work on top of discord.py's ratelimiting.
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class RatelimitBucket:
|
||||||
|
"""Main ratelimit bucket class."""
|
||||||
|
def __init__(self, tokens, second):
|
||||||
|
self.requests = tokens
|
||||||
|
self.second = second
|
||||||
|
|
||||||
|
self._window = 0.0
|
||||||
|
self._tokens = self.requests
|
||||||
|
self.retries = 0
|
||||||
|
self._last = 0.0
|
||||||
|
|
||||||
|
def get_tokens(self, current):
|
||||||
|
"""Get the current amount of available tokens."""
|
||||||
|
if not current:
|
||||||
|
current = time.time()
|
||||||
|
|
||||||
|
# by default, use _tokens
|
||||||
|
tokens = self._tokens
|
||||||
|
|
||||||
|
# if current timestamp is above _window + seconds
|
||||||
|
# reset tokens to self.requests (default)
|
||||||
|
if current > self._window + self.second:
|
||||||
|
tokens = self.requests
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def update_rate_limit(self):
|
||||||
|
"""Update current ratelimit state."""
|
||||||
|
current = time.time()
|
||||||
|
self._last = current
|
||||||
|
self._tokens = self.get_tokens(current)
|
||||||
|
|
||||||
|
# we are using the ratelimit for the first time
|
||||||
|
# so set current ratelimit window to right now
|
||||||
|
if self._tokens == self.requests:
|
||||||
|
self._window = current
|
||||||
|
|
||||||
|
# Are we currently ratelimited?
|
||||||
|
if self._tokens == 0:
|
||||||
|
self.retries += 1
|
||||||
|
return self.second - (current - self._window)
|
||||||
|
|
||||||
|
# if not ratelimited, remove a token
|
||||||
|
self.retries = 0
|
||||||
|
self._tokens -= 1
|
||||||
|
|
||||||
|
# if we got ratelimited after that token removal,
|
||||||
|
# set window to now
|
||||||
|
if self._tokens == 0:
|
||||||
|
self._window = current
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Reset current ratelimit to default state."""
|
||||||
|
self._tokens = self.requests
|
||||||
|
self._last = 0.0
|
||||||
|
self.retries = 0
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
"""Create a copy of this ratelimit.
|
||||||
|
|
||||||
|
Used to manage multiple ratelimits to users.
|
||||||
|
"""
|
||||||
|
return RatelimitBucket(self.requests,
|
||||||
|
self.second)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (f'<RatelimitBucket requests={self.requests} '
|
||||||
|
f'second={self.second} window: {self._window} '
|
||||||
|
f'tokens={self._tokens}>')
|
||||||
|
|
||||||
|
|
||||||
|
class Ratelimit:
|
||||||
|
"""Manages buckets."""
|
||||||
|
def __init__(self, tokens, second, keys=None):
|
||||||
|
self._cache = {}
|
||||||
|
if keys is None:
|
||||||
|
keys = tuple()
|
||||||
|
self.keys = keys
|
||||||
|
self._cooldown = RatelimitBucket(tokens, second)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (f'<Ratelimit cooldown={self._cooldown}>')
|
||||||
|
|
||||||
|
def _verify_cache(self):
|
||||||
|
current = time.time()
|
||||||
|
dead_keys = [k for k, v in self._cache.items()
|
||||||
|
if current > v._last + v.second]
|
||||||
|
|
||||||
|
for k in dead_keys:
|
||||||
|
del self._cache[k]
|
||||||
|
|
||||||
|
def get_bucket(self, key) -> RatelimitBucket:
|
||||||
|
if not self._cooldown:
|
||||||
|
return None
|
||||||
|
|
||||||
|
self._verify_cache()
|
||||||
|
|
||||||
|
if key not in self._cache:
|
||||||
|
bucket = self._cooldown.copy()
|
||||||
|
self._cache[key] = bucket
|
||||||
|
else:
|
||||||
|
bucket = self._cache[key]
|
||||||
|
|
||||||
|
return bucket
|
||||||
|
|
@ -0,0 +1,67 @@
|
||||||
|
from quart import current_app as app, request, g
|
||||||
|
|
||||||
|
from litecord.errors import Ratelimited
|
||||||
|
from litecord.auth import token_check, Unauthorized
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_bucket(bucket):
|
||||||
|
retry_after = bucket.update_rate_limit()
|
||||||
|
|
||||||
|
request.bucket = bucket
|
||||||
|
|
||||||
|
if retry_after:
|
||||||
|
raise Ratelimited('You are being ratelimited.', {
|
||||||
|
'retry_after': retry_after
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_global(ratelimit):
|
||||||
|
"""Global ratelimit is per-user."""
|
||||||
|
try:
|
||||||
|
user_id = await token_check()
|
||||||
|
except Unauthorized:
|
||||||
|
user_id = request.remote_addr
|
||||||
|
|
||||||
|
bucket = ratelimit.get_bucket(user_id)
|
||||||
|
await _check_bucket(bucket)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_specific(ratelimit):
|
||||||
|
try:
|
||||||
|
user_id = await token_check()
|
||||||
|
except Unauthorized:
|
||||||
|
user_id = request.remote_addr
|
||||||
|
|
||||||
|
# construct the key based on the ratelimit.keys
|
||||||
|
keys = ratelimit.keys
|
||||||
|
|
||||||
|
# base key is the user id
|
||||||
|
key_components = [f'user_id:{user_id}']
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
val = request.view_args[key]
|
||||||
|
key_components.append(f'{key}:{val}')
|
||||||
|
|
||||||
|
bucket_key = ':'.join(key_components)
|
||||||
|
bucket = ratelimit.get_bucket(bucket_key)
|
||||||
|
await _check_bucket(bucket)
|
||||||
|
|
||||||
|
|
||||||
|
async def ratelimit_handler():
|
||||||
|
"""Main ratelimit handler.
|
||||||
|
|
||||||
|
Decides on which ratelimit to use.
|
||||||
|
"""
|
||||||
|
rule = request.url_rule
|
||||||
|
|
||||||
|
# rule.endpoint is composed of '<blueprint>.<function>'
|
||||||
|
# and so we can use that to make routes with different
|
||||||
|
# methods have different ratelimits
|
||||||
|
rule_path = rule.endpoint
|
||||||
|
|
||||||
|
try:
|
||||||
|
ratelimit = app.ratelimiter.get_ratelimit(rule_path)
|
||||||
|
await _handle_specific(ratelimit)
|
||||||
|
except KeyError:
|
||||||
|
ratelimit = app.ratelimiter.global_bucket
|
||||||
|
await _handle_global(ratelimit)
|
||||||
|
|
@ -1,5 +1,53 @@
|
||||||
from quart import current_app as app, request
|
from litecord.ratelimits.bucket import Ratelimit
|
||||||
|
|
||||||
async def ratelimit_handler():
|
"""
|
||||||
# dummy handler for future code
|
REST:
|
||||||
print(request.headers)
|
POST Message | 5/5s | per-channel
|
||||||
|
DELETE Message | 5/1s | per-channel
|
||||||
|
PUT/DELETE Reaction | 1/0.25s | per-channel
|
||||||
|
PATCH Member | 10/10s | per-guild
|
||||||
|
PATCH Member Nick | 1/1s | per-guild
|
||||||
|
PATCH Username | 2/3600s | per-account
|
||||||
|
|All Requests| | 50/1s | per-account
|
||||||
|
WS:
|
||||||
|
Gateway Connect | 1/5s | per-account
|
||||||
|
Presence Update | 5/60s | per-session
|
||||||
|
|All Sent Messages| | 120/60s | per-session
|
||||||
|
"""
|
||||||
|
|
||||||
|
REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id'))
|
||||||
|
|
||||||
|
RATELIMITS = {
|
||||||
|
'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')),
|
||||||
|
'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')),
|
||||||
|
|
||||||
|
# all of those share the same bucket.
|
||||||
|
'channel_reactions.add_reaction': REACTION_BUCKET,
|
||||||
|
'channel_reactions.remove_own_reaction': REACTION_BUCKET,
|
||||||
|
'channel_reactions.remove_user_reaction': REACTION_BUCKET,
|
||||||
|
|
||||||
|
'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')),
|
||||||
|
'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')),
|
||||||
|
|
||||||
|
# this only applies to username.
|
||||||
|
# 'users.patch_me': Ratelimit(2, 3600),
|
||||||
|
|
||||||
|
'_ws.connect': Ratelimit(1, 5),
|
||||||
|
'_ws.presence': Ratelimit(5, 60),
|
||||||
|
'_ws.messages': Ratelimit(120, 60),
|
||||||
|
}
|
||||||
|
|
||||||
|
class RatelimitManager:
|
||||||
|
"""Manager for the bucket managers"""
|
||||||
|
def __init__(self):
|
||||||
|
self._ratelimiters = {}
|
||||||
|
self.global_bucket = Ratelimit(50, 1)
|
||||||
|
self._fill_rtl()
|
||||||
|
|
||||||
|
def _fill_rtl(self):
|
||||||
|
for path, rtl in RATELIMITS.items():
|
||||||
|
self._ratelimiters[path] = rtl
|
||||||
|
|
||||||
|
def get_ratelimit(self, key: str) -> Ratelimit:
|
||||||
|
"""Get the :class:`Ratelimit` instance for a given path."""
|
||||||
|
return self._ratelimiters.get(key, self.global_bucket)
|
||||||
|
|
|
||||||
26
run.py
26
run.py
|
|
@ -9,14 +9,14 @@ from quart import Quart, g, jsonify, request
|
||||||
from logbook import StreamHandler, Logger
|
from logbook import StreamHandler, Logger
|
||||||
from logbook.compat import redirect_logging
|
from logbook.compat import redirect_logging
|
||||||
|
|
||||||
|
# import the config set by instance owner
|
||||||
import config
|
import config
|
||||||
|
|
||||||
from litecord.blueprints import (
|
from litecord.blueprints import (
|
||||||
gateway, auth, users, guilds, channels, webhooks, science,
|
gateway, auth, users, guilds, channels, webhooks, science,
|
||||||
voice, invites, relationships, dms
|
voice, invites, relationships, dms
|
||||||
)
|
)
|
||||||
|
|
||||||
from litecord.ratelimits.main import ratelimit_handler
|
|
||||||
|
|
||||||
# those blueprints are separated from the "main" ones
|
# those blueprints are separated from the "main" ones
|
||||||
# for code readability if people want to dig through
|
# for code readability if people want to dig through
|
||||||
# the codebase.
|
# the codebase.
|
||||||
|
|
@ -28,6 +28,9 @@ from litecord.blueprints.channel import (
|
||||||
channel_messages, channel_reactions
|
channel_messages, channel_reactions
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from litecord.ratelimits.handler import ratelimit_handler
|
||||||
|
from litecord.ratelimits.main import RatelimitManager
|
||||||
|
|
||||||
from litecord.gateway import websocket_handler
|
from litecord.gateway import websocket_handler
|
||||||
from litecord.errors import LitecordError
|
from litecord.errors import LitecordError
|
||||||
from litecord.gateway.state_manager import StateManager
|
from litecord.gateway.state_manager import StateManager
|
||||||
|
|
@ -110,6 +113,21 @@ async def app_after_request(resp):
|
||||||
# resp.headers['Access-Control-Allow-Methods'] = '*'
|
# resp.headers['Access-Control-Allow-Methods'] = '*'
|
||||||
resp.headers['Access-Control-Allow-Methods'] = \
|
resp.headers['Access-Control-Allow-Methods'] = \
|
||||||
resp.headers.get('allow', '*')
|
resp.headers.get('allow', '*')
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@app.after_request
|
||||||
|
async def app_set_ratelimit_headers(resp):
|
||||||
|
"""Set the specific ratelimit headers."""
|
||||||
|
try:
|
||||||
|
bucket = request.bucket
|
||||||
|
resp.headers['X-RateLimit-Limit'] = str(bucket.requests)
|
||||||
|
resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens)
|
||||||
|
resp.headers['X-RateLimit-Reset'] = str(bucket._window + bucket.second)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -123,6 +141,7 @@ async def app_before_serving():
|
||||||
app.loop = asyncio.get_event_loop()
|
app.loop = asyncio.get_event_loop()
|
||||||
g.loop = asyncio.get_event_loop()
|
g.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
app.ratelimiter = RatelimitManager()
|
||||||
app.state_manager = StateManager()
|
app.state_manager = StateManager()
|
||||||
app.storage = Storage(app.db)
|
app.storage = Storage(app.db)
|
||||||
|
|
||||||
|
|
@ -141,7 +160,8 @@ async def app_before_serving():
|
||||||
|
|
||||||
# TODO: pass just the app object
|
# TODO: pass just the app object
|
||||||
await websocket_handler((app.db, app.state_manager, app.storage,
|
await websocket_handler((app.db, app.state_manager, app.storage,
|
||||||
app.loop, app.dispatcher, app.presence),
|
app.loop, app.dispatcher, app.presence,
|
||||||
|
app.ratelimiter),
|
||||||
ws, url)
|
ws, url)
|
||||||
|
|
||||||
ws_future = websockets.serve(_wrapper, host, port)
|
ws_future = websockets.serve(_wrapper, host, port)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue