all: add ratelimit implementation

haven't tested yet, but it should work in theory.

 - gateway.websocket: add the 3 main ws ratelimits
 - litecord: add ratelimits package
 - ratelimits.main: add implementation
 - run: add app_set_ratelimit_headers
This commit is contained in:
Luna Mendes 2018-11-04 02:06:40 -03:00
parent b17cfd46eb
commit 33f893c0ff
6 changed files with 288 additions and 8 deletions

View File

@ -41,6 +41,10 @@ class MessageNotFound(LitecordError):
status_code = 404
class Ratelimited(LitecordError):
status_code = 429
class WebsocketClose(Exception):
@property
def code(self):

View File

@ -28,7 +28,8 @@ WebsocketProperties = collections.namedtuple(
)
WebsocketObjects = collections.namedtuple(
'WebsocketObjects', 'db state_manager storage loop dispatcher presence'
'WebsocketObjects', ('db', 'state_manager', 'storage',
'loop', 'dispatcher', 'presence', 'ratelimiter')
)
@ -138,6 +139,11 @@ class GatewayWebsocket:
else:
await self.ws.send(encoded.decode())
def _check_ratelimit(self, key: str, ratelimit_key: str):
ratelimit = self.ext.ratelimiter.get_ratelimit(f'_ws.{key}')
bucket = ratelimit.get_bucket(ratelimit_key)
return bucket.update_rate_limit()
async def _hb_wait(self, interval: int):
"""Wait heartbeat"""
# if the client heartbeats in time,
@ -342,6 +348,14 @@ class GatewayWebsocket:
async def update_status(self, status: dict):
"""Update the status of the current websocket connection."""
if not self.state:
return
if self._check_ratelimit('presence', self.state.session_id):
# Presence Updates beyond the ratelimit
# are just silently dropped.
return
if status is None:
status = {
'afk': False,
@ -395,6 +409,11 @@ class GatewayWebsocket:
'op': OP.HEARTBEAT_ACK,
})
async def _connect_ratelimit(self, user_id: int):
if self._check_ratelimit('connect', user_id):
await self.invalidate_session(False)
raise WebsocketClose(4009, 'You are being ratelimited.')
async def handle_2(self, payload: Dict[str, Any]):
"""Handle the OP 2 Identify packet."""
try:
@ -414,6 +433,8 @@ class GatewayWebsocket:
except (Unauthorized, Forbidden):
raise WebsocketClose(4004, 'Authentication failed')
await self._connect_ratelimit(user_id)
bot = await self.ext.db.fetchval("""
SELECT bot FROM users
WHERE id = $1
@ -751,6 +772,10 @@ class GatewayWebsocket:
await handler(payload)
async def _msg_ratelimit(self):
if self._check_ratelimit('messages', self.state.session_id):
raise WebsocketClose(4008, 'You are being ratelimited.')
async def listen_messages(self):
"""Listen for messages coming in from the websocket."""
@ -767,6 +792,9 @@ class GatewayWebsocket:
if len(message) > 4096:
raise DecodeError('Payload length exceeded')
if self.state:
await self._msg_ratelimit()
payload = self.decoder(message)
await self.process_message(payload)

View File

@ -0,0 +1,113 @@
"""
main litecord ratelimiting code
This code was copied from elixire's ratelimiting,
which in turn is a work on top of discord.py's ratelimiting.
"""
import time
class RatelimitBucket:
"""Main ratelimit bucket class."""
def __init__(self, tokens, second):
self.requests = tokens
self.second = second
self._window = 0.0
self._tokens = self.requests
self.retries = 0
self._last = 0.0
def get_tokens(self, current):
"""Get the current amount of available tokens."""
if not current:
current = time.time()
# by default, use _tokens
tokens = self._tokens
# if current timestamp is above _window + seconds
# reset tokens to self.requests (default)
if current > self._window + self.second:
tokens = self.requests
return tokens
def update_rate_limit(self):
"""Update current ratelimit state."""
current = time.time()
self._last = current
self._tokens = self.get_tokens(current)
# we are using the ratelimit for the first time
# so set current ratelimit window to right now
if self._tokens == self.requests:
self._window = current
# Are we currently ratelimited?
if self._tokens == 0:
self.retries += 1
return self.second - (current - self._window)
# if not ratelimited, remove a token
self.retries = 0
self._tokens -= 1
# if we got ratelimited after that token removal,
# set window to now
if self._tokens == 0:
self._window = current
def reset(self):
"""Reset current ratelimit to default state."""
self._tokens = self.requests
self._last = 0.0
self.retries = 0
def copy(self):
"""Create a copy of this ratelimit.
Used to manage multiple ratelimits to users.
"""
return RatelimitBucket(self.requests,
self.second)
def __repr__(self):
return (f'<RatelimitBucket requests={self.requests} '
f'second={self.second} window: {self._window} '
f'tokens={self._tokens}>')
class Ratelimit:
"""Manages buckets."""
def __init__(self, tokens, second, keys=None):
self._cache = {}
if keys is None:
keys = tuple()
self.keys = keys
self._cooldown = RatelimitBucket(tokens, second)
def __repr__(self):
return (f'<Ratelimit cooldown={self._cooldown}>')
def _verify_cache(self):
current = time.time()
dead_keys = [k for k, v in self._cache.items()
if current > v._last + v.second]
for k in dead_keys:
del self._cache[k]
def get_bucket(self, key) -> RatelimitBucket:
if not self._cooldown:
return None
self._verify_cache()
if key not in self._cache:
bucket = self._cooldown.copy()
self._cache[key] = bucket
else:
bucket = self._cache[key]
return bucket

View File

@ -0,0 +1,67 @@
from quart import current_app as app, request, g
from litecord.errors import Ratelimited
from litecord.auth import token_check, Unauthorized
async def _check_bucket(bucket):
retry_after = bucket.update_rate_limit()
request.bucket = bucket
if retry_after:
raise Ratelimited('You are being ratelimited.', {
'retry_after': retry_after
})
async def _handle_global(ratelimit):
"""Global ratelimit is per-user."""
try:
user_id = await token_check()
except Unauthorized:
user_id = request.remote_addr
bucket = ratelimit.get_bucket(user_id)
await _check_bucket(bucket)
async def _handle_specific(ratelimit):
try:
user_id = await token_check()
except Unauthorized:
user_id = request.remote_addr
# construct the key based on the ratelimit.keys
keys = ratelimit.keys
# base key is the user id
key_components = [f'user_id:{user_id}']
for key in keys:
val = request.view_args[key]
key_components.append(f'{key}:{val}')
bucket_key = ':'.join(key_components)
bucket = ratelimit.get_bucket(bucket_key)
await _check_bucket(bucket)
async def ratelimit_handler():
"""Main ratelimit handler.
Decides on which ratelimit to use.
"""
rule = request.url_rule
# rule.endpoint is composed of '<blueprint>.<function>'
# and so we can use that to make routes with different
# methods have different ratelimits
rule_path = rule.endpoint
try:
ratelimit = app.ratelimiter.get_ratelimit(rule_path)
await _handle_specific(ratelimit)
except KeyError:
ratelimit = app.ratelimiter.global_bucket
await _handle_global(ratelimit)

View File

@ -1,5 +1,53 @@
from quart import current_app as app, request
from litecord.ratelimits.bucket import Ratelimit
async def ratelimit_handler():
# dummy handler for future code
print(request.headers)
"""
REST:
POST Message | 5/5s | per-channel
DELETE Message | 5/1s | per-channel
PUT/DELETE Reaction | 1/0.25s | per-channel
PATCH Member | 10/10s | per-guild
PATCH Member Nick | 1/1s | per-guild
PATCH Username | 2/3600s | per-account
|All Requests| | 50/1s | per-account
WS:
Gateway Connect | 1/5s | per-account
Presence Update | 5/60s | per-session
|All Sent Messages| | 120/60s | per-session
"""
REACTION_BUCKET = Ratelimit(1, 0.25, ('channel_id'))
RATELIMITS = {
'channel_messages.create_message': Ratelimit(5, 5, ('channel_id')),
'channel_messages.delete_message': Ratelimit(5, 1, ('channel_id')),
# all of those share the same bucket.
'channel_reactions.add_reaction': REACTION_BUCKET,
'channel_reactions.remove_own_reaction': REACTION_BUCKET,
'channel_reactions.remove_user_reaction': REACTION_BUCKET,
'guild_members.modify_guild_member': Ratelimit(10, 10, ('guild_id')),
'guild_members.update_nickname': Ratelimit(1, 1, ('guild_id')),
# this only applies to username.
# 'users.patch_me': Ratelimit(2, 3600),
'_ws.connect': Ratelimit(1, 5),
'_ws.presence': Ratelimit(5, 60),
'_ws.messages': Ratelimit(120, 60),
}
class RatelimitManager:
"""Manager for the bucket managers"""
def __init__(self):
self._ratelimiters = {}
self.global_bucket = Ratelimit(50, 1)
self._fill_rtl()
def _fill_rtl(self):
for path, rtl in RATELIMITS.items():
self._ratelimiters[path] = rtl
def get_ratelimit(self, key: str) -> Ratelimit:
"""Get the :class:`Ratelimit` instance for a given path."""
return self._ratelimiters.get(key, self.global_bucket)

26
run.py
View File

@ -9,14 +9,14 @@ from quart import Quart, g, jsonify, request
from logbook import StreamHandler, Logger
from logbook.compat import redirect_logging
# import the config set by instance owner
import config
from litecord.blueprints import (
gateway, auth, users, guilds, channels, webhooks, science,
voice, invites, relationships, dms
)
from litecord.ratelimits.main import ratelimit_handler
# those blueprints are separated from the "main" ones
# for code readability if people want to dig through
# the codebase.
@ -28,6 +28,9 @@ from litecord.blueprints.channel import (
channel_messages, channel_reactions
)
from litecord.ratelimits.handler import ratelimit_handler
from litecord.ratelimits.main import RatelimitManager
from litecord.gateway import websocket_handler
from litecord.errors import LitecordError
from litecord.gateway.state_manager import StateManager
@ -110,6 +113,21 @@ async def app_after_request(resp):
# resp.headers['Access-Control-Allow-Methods'] = '*'
resp.headers['Access-Control-Allow-Methods'] = \
resp.headers.get('allow', '*')
return resp
@app.after_request
async def app_set_ratelimit_headers(resp):
"""Set the specific ratelimit headers."""
try:
bucket = request.bucket
resp.headers['X-RateLimit-Limit'] = str(bucket.requests)
resp.headers['X-RateLimit-Remaining'] = str(bucket._tokens)
resp.headers['X-RateLimit-Reset'] = str(bucket._window + bucket.second)
except AttributeError:
pass
return resp
@ -123,6 +141,7 @@ async def app_before_serving():
app.loop = asyncio.get_event_loop()
g.loop = asyncio.get_event_loop()
app.ratelimiter = RatelimitManager()
app.state_manager = StateManager()
app.storage = Storage(app.db)
@ -141,7 +160,8 @@ async def app_before_serving():
# TODO: pass just the app object
await websocket_handler((app.db, app.state_manager, app.storage,
app.loop, app.dispatcher, app.presence),
app.loop, app.dispatcher, app.presence,
app.ratelimiter),
ws, url)
ws_future = websockets.serve(_wrapper, host, port)